# Copyright 2008 Free Software Foundation, Inc.
#
# This file is part of GNU Radio
#
# GNU Radio is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# GNU Radio is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GNU Radio; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.
#

default_win_size = 1000

from gnuradio import gr
from gnuradio import blocks
import gnuradio.gr.gr_threading as _threading
import numpy

#generate 1s counts array
_1s_counts = [sum([1&(i>>j) for j in range(8)]) for i in range(2**8)]

class input_watcher(_threading.Thread):
    """
    Read samples from the message queue and hand them to the callback.
    """

    def __init__(self, msgq, callback):
        self._msgq = msgq
        self._callback = callback
        _threading.Thread.__init__(self)
        self.setDaemon(1)
        self.keep_running = True
        self.start()

    def run(self):
        r = ''
        while True:
            msg = self._msgq.delete_head()
            itemsize = int(msg.arg1())
            nitems = int(msg.arg2())
            s = r + msg.to_string()
            i = (nitems-nitems%2)*itemsize
            r = s[i:]
            s = s[:i]
            samples = numpy.fromstring(s, numpy.int8)
            self._callback(samples)

class error_rate(gr.hier_block2):
    """
    Sample the incoming data streams (byte) and calculate the bit or symbol error rate.
    Write the running rate to the output data stream (float).
    """

    def __init__(self, type='BER', win_size=default_win_size, bits_per_symbol=2):
        """
        Error rate constructor.

        Args:
            type: a string 'BER' or 'SER'
            win_size: the number of samples to calculate over
            bits_per_symbol: the number of information bits per symbol (BER only)
        """
        #init
        gr.hier_block2.__init__(
            self, 'error_rate',
            gr.io_signature(2, 2, gr.sizeof_char),
            gr.io_signature(1, 1, gr.sizeof_float),
        )

        print "Warning: the blks2.error_rate is deprecated."

        assert type in ('BER', 'SER')
        self._max_samples = win_size
        self._bits_per_symbol = bits_per_symbol
        #setup message queue
        msg_source = blocks.message_source(gr.sizeof_float, 1)
        self._msgq_source = msg_source.msgq()
        msgq_sink = gr.msg_queue(2)
        msg_sink = blocks.message_sink(gr.sizeof_char, msgq_sink, False) #False -> blocking
        inter = blocks.interleave(gr.sizeof_char)
        #start thread
        self._num_errs = 0
        self._err_index = 0
        self._num_samps = 0
        self._err_array = numpy.zeros(self._max_samples, numpy.int8)
        if type == 'BER':
            input_watcher(msgq_sink, self._handler_ber)
        elif type == 'SER':
            input_watcher(msgq_sink, self._handler_ser)
        #connect
        self.connect(msg_source, self)
        self.connect((self, 0), (inter, 0))
        self.connect((self, 1), (inter, 1))
        self.connect(inter, msg_sink)

    def _handler_ber(self, samples):
        num = len(samples)/2
        arr = numpy.zeros(num, numpy.float32)
        for i in range(num):
            old_err = self._err_array[self._err_index]
            #record error
            self._err_array[self._err_index] = _1s_counts[samples[i*2] ^ samples[i*2 + 1]]
            self._num_errs = self._num_errs + self._err_array[self._err_index] - old_err
            #increment index
            self._err_index = (self._err_index + 1)%self._max_samples
            self._num_samps = min(self._num_samps + 1, self._max_samples)
            #write sample
            arr[i] = float(self._num_errs)/float(self._num_samps*self._bits_per_symbol)
        #write message
        msg = gr.message_from_string(arr.tostring(), 0, gr.sizeof_float, num)
        self._msgq_source.insert_tail(msg)

    def _handler_ser(self, samples):
        num = len(samples)/2
        arr = numpy.zeros(num, numpy.float32)
        for i in range(num):
            old_err = self._err_array[self._err_index]
            #record error
            ref = samples[i*2]
            res = samples[i*2 + 1]
            if ref == res:
                self._err_array[self._err_index] = 0
            else:
                self._err_array[self._err_index] = 1
            #update number of errors
            self._num_errs = self._num_errs + self._err_array[self._err_index] - old_err
            #increment index
            self._err_index = (self._err_index + 1)%self._max_samples
            self._num_samps = min(self._num_samps + 1, self._max_samples)
            #write sample
            arr[i] = float(self._num_errs)/float(self._num_samps)
        #write message
        msg = gr.message_from_string(arr.tostring(), 0, gr.sizeof_float, num)
        self._msgq_source.insert_tail(msg)