#!/usr/bin/env python
# Copyright 2012 Free Software Foundation, Inc.
# 
# This file is part of GNU Radio
# 
# GNU Radio is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
# 
# GNU Radio is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with GNU Radio; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.
# 

import time

from gnuradio import gr, gr_unittest, digital, blocks
import pmt

class qa_header_payload_demux (gr_unittest.TestCase):
    """
    QA tests for digital.header_payload_demux.

    Every test feeds a float data stream and a byte trigger stream into the
    demuxer, fakes the header parser's reply by posting a message to the
    'header_data' port, and then checks the items (and, where applicable,
    the tags) collected on the header and payload outputs.
    """

    # Upper bound, in seconds, on how long a test waits for payload items.
    # Without it a demuxer that never emits the payload would hang the
    # whole test run instead of producing a failure.
    WAIT_TIMEOUT = 30.0
    # Interval, in seconds, between two checks of the payload sink.
    POLL_INTERVAL = .2

    def setUp (self):
        self.tb = gr.top_block ()

    def tearDown (self):
        self.tb = None

    def _make_tag (self, offset, key, value):
        """ Return a gr.tag_t at `offset` with symbol `key` and integer `value`. """
        tag = gr.tag_t()
        tag.offset = offset
        tag.key = pmt.string_to_symbol(key)
        tag.value = pmt.from_long(value)
        return tag

    def _tags_as_dicts (self, sink):
        """ Return the tags seen by `sink` as a list of {'key', 'offset'} dicts. """
        ptags = []
        for tag in sink.tags():
            ptag = gr.tag_to_python(tag)
            ptags.append({'key': ptag.key, 'offset': ptag.offset})
        return ptags

    def _connect_demux (self, data_src, trigger_src, hpd, header_sink, payload_sink):
        """ Wire sources -> demuxer inputs 0/1 and demuxer outputs 0/1 -> sinks. """
        self.tb.connect(data_src,    (hpd, 0))
        self.tb.connect(trigger_src, (hpd, 1))
        self.tb.connect((hpd, 0), header_sink)
        self.tb.connect((hpd, 1), payload_sink)

    def _post_header_data (self, hpd, msg):
        """ Fake the header parser: post `msg` to the demuxer's 'header_data' port. """
        hpd.to_basic_block()._post(pmt.intern('header_data'), msg)

    def _wait_for_payload (self, payload_sink, n_items):
        """
        Poll `payload_sink` until it holds at least `n_items` items.

        If WAIT_TIMEOUT seconds pass first, the flowgraph is stopped and the
        test fails, rather than blocking forever.
        """
        deadline = time.time() + self.WAIT_TIMEOUT
        while len(payload_sink.data()) < n_items:
            if time.time() > deadline:
                # Shut the flowgraph down cleanly before reporting the failure
                self.tb.stop()
                self.tb.wait()
                self.fail(
                    "Timed out after %.1f s waiting for %d payload items (got %d)" % (
                        self.WAIT_TIMEOUT, n_items, len(payload_sink.data())
                    )
                )
            time.sleep(self.POLL_INTERVAL)

    def test_001_t (self):
        """ Simplest possible test: put in zeros, then header,
        then payload, trigger signal, try to demux.
        The return signal from the header parser is faked via _post()
        Add in some tags for fun.
        """
        n_zeros = 1
        header = (1, 2, 3)
        payload = tuple(range(5, 20))
        data_signal = (0,) * n_zeros + header + payload
        trigger_signal = [0,] * len(data_signal)
        trigger_signal[n_zeros] = 1
        tags = (
            # This is dropped:
            self._make_tag(0, 'tag1', 0),
            # This goes on output 0, item 0:
            self._make_tag(n_zeros, 'tag2', 23),
            # This goes on output 0, item 2:
            self._make_tag(n_zeros + len(header) - 1, 'tag3', 42),
            # This goes on output 1, item 3:
            self._make_tag(n_zeros + len(header) + 3, 'tag4', 314),
        )

        data_src = blocks.vector_source_f(data_signal, False, tags=tags)
        trigger_src = blocks.vector_source_b(trigger_signal, False)
        hpd = digital.header_payload_demux(
            len(header), 1, 0, "frame_len", "detect", False, gr.sizeof_float
        )
        self.assertEqual(pmt.length(hpd.message_ports_in()), 1)
        header_sink = blocks.vector_sink_f()
        payload_sink = blocks.vector_sink_f()

        self._connect_demux(data_src, trigger_src, hpd, header_sink, payload_sink)
        self.tb.start()
        time.sleep(.2) # Need this, otherwise, the next message is ignored
        self._post_header_data(hpd, pmt.from_long(len(payload)))
        self._wait_for_payload(payload_sink, len(payload))
        self.tb.stop()
        self.tb.wait()

        self.assertEqual(header_sink.data(),  header)
        self.assertEqual(payload_sink.data(), payload)
        expected_tags_header = [
                {'key': 'tag2', 'offset': 0},
                {'key': 'tag3', 'offset': 2},
        ]
        self.assertEqual(expected_tags_header, self._tags_as_dicts(header_sink))
        expected_tags_payload = [
                {'key': 'frame_len', 'offset': 0},
                {'key': 'tag4', 'offset': 3},
        ]
        self.assertEqual(expected_tags_payload, self._tags_as_dicts(payload_sink))

    def test_002_symbols (self):
        """
        Same as before, but operate on symbols
        """
        n_zeros = 1
        items_per_symbol = 3
        gi = 1
        n_symbols = 4
        header = (1, 2, 3)
        payload = (1, 2, 3)
        data_signal = (0,) * n_zeros + (0,) + header + ((0,) + payload) * n_symbols
        trigger_signal = [0,] * len(data_signal)
        trigger_signal[n_zeros] = 1
        tags = (
            # This is dropped:
            self._make_tag(0, 'tag1', 0),
            # This goes on output 0, item 0 (from the GI)
            self._make_tag(n_zeros, 'tag2', 23),
            # This goes on output 0, item 0 (middle of the header symbol)
            self._make_tag(n_zeros + gi + 1, 'tag3', 42),
            # This goes on output 1, item 1 (middle of the first payload symbol)
            self._make_tag(n_zeros + (gi + items_per_symbol) * 2 + 1, 'tag4', 314),
        )
        data_src = blocks.vector_source_f(data_signal, False, tags=tags)
        trigger_src = blocks.vector_source_b(trigger_signal, False)
        hpd = digital.header_payload_demux(
            # Floor division keeps this an integer even under true division
            len(header) // items_per_symbol, # Header length (in symbols)
            items_per_symbol,                # Items per symbols
            gi,                              # Items per guard time
            "frame_len",                     # Frame length tag key
            "detect",                        # Trigger tag key
            True,                            # Output symbols (not items)
            gr.sizeof_float                  # Bytes per item
        )
        self.assertEqual(pmt.length(hpd.message_ports_in()), 1)
        header_sink = blocks.vector_sink_f(items_per_symbol)
        payload_sink = blocks.vector_sink_f(items_per_symbol)
        self._connect_demux(data_src, trigger_src, hpd, header_sink, payload_sink)
        self.tb.start()
        time.sleep(.2) # Need this, otherwise, the next message is ignored
        self._post_header_data(hpd, pmt.from_long(n_symbols))
        self._wait_for_payload(payload_sink, len(payload) * n_symbols)
        self.tb.stop()
        self.tb.wait()
        self.assertEqual(header_sink.data(),  header)
        self.assertEqual(payload_sink.data(), payload * n_symbols)
        expected_tags_header = [
                {'key': 'tag2', 'offset': 0},
                {'key': 'tag3', 'offset': 0},
        ]
        self.assertEqual(expected_tags_header, self._tags_as_dicts(header_sink))
        expected_tags_payload = [
                {'key': 'frame_len', 'offset': 0},
                {'key': 'tag4', 'offset': 1},
        ]
        self.assertEqual(expected_tags_payload, self._tags_as_dicts(payload_sink))

    def test_003_t (self):
        """
        Like test 1, but twice, plus one fail
        """
        n_zeros = 5
        header = (1, 2, 3)
        header_fail = (-1, -2, -4) # Contents don't really matter
        payload1 = tuple(range(5, 20))
        payload2 = (42,)
        data_signal = (0,) * n_zeros + header + payload1
        trigger_signal = [0,] * len(data_signal) * 2
        # One trigger per header: first frame, failing header, second frame
        trigger_signal[n_zeros] = 1
        trigger_signal[len(data_signal)] = 1
        trigger_signal[len(data_signal)+len(header_fail)+n_zeros] = 1
        tx_signal = data_signal + header_fail + (0,) * n_zeros + header + payload2 + (0,) * 1000
        data_src = blocks.vector_source_f(tx_signal, False)
        trigger_src = blocks.vector_source_b(trigger_signal, False)
        hpd = digital.header_payload_demux(
            len(header), 1, 0, "frame_len", "detect", False, gr.sizeof_float
        )
        self.assertEqual(pmt.length(hpd.message_ports_in()), 1)
        header_sink = blocks.vector_sink_f()
        payload_sink = blocks.vector_sink_f()
        self._connect_demux(data_src, trigger_src, hpd, header_sink, payload_sink)
        self.tb.start()
        time.sleep(.2) # Need this, otherwise, the next message is ignored
        self._post_header_data(hpd, pmt.from_long(len(payload1)))
        self._wait_for_payload(payload_sink, len(payload1))
        # Reply to the second header with PMT_F (the "fail" case of this test)
        self._post_header_data(hpd, pmt.PMT_F)
        # This next command is a bit of a showstopper, but there's no condition to check upon
        # to see if the previous msg handling is finished
        time.sleep(.7)
        self._post_header_data(hpd, pmt.from_long(len(payload2)))
        self._wait_for_payload(payload_sink, len(payload1) + len(payload2))
        self.tb.stop()
        self.tb.wait()
        self.assertEqual(header_sink.data(),  header + header_fail + header)
        self.assertEqual(payload_sink.data(), payload1 + payload2)

if __name__ == '__main__':
    # Executed as a script: hand the test case class and the name of its
    # XML results file to the GNU Radio unittest runner.
    gr_unittest.run(
        qa_header_payload_demux,
        "qa_header_payload_demux.xml"
    )