summary refs log tree commit diff
path: root/gr-fec/python/fec/polar/encoder.py
blob: 3b5eea2a9485644caf61d48bd5bdb3256a1572dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
#
# Copyright 2015 Free Software Foundation, Inc.
#
# GNU Radio is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# GNU Radio is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GNU Radio; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.
#

import numpy as np
from common import PolarCommon
import helper_functions as hf


class PolarEncoder(PolarCommon):
    """Polar code encoder built on top of PolarCommon.

    Constructor arguments are forwarded unchanged to PolarCommon. This class
    adds the matrix ``self.G`` (from ``hf.get_Fn(n)``) and two encoding
    paths: a plain matrix product (``_encode_matrix``) and a staged
    butterfly network (``_encode_efficient``).
    """

    def __init__(self, n, k, frozen_bit_position, frozenbits=None):
        PolarCommon.__init__(self, n, k, frozen_bit_position, frozenbits)
        # Only the reference matrix encoder uses this matrix.
        self.G = hf.get_Fn(n)

    def get_gn(self):
        """Return the matrix computed by ``hf.get_Fn(n)`` in the constructor."""
        return self.G

    def _prepare_input_data(self, vec):
        """Turn information bits into the encoder input vector.

        Applies PolarCommon's ``_insert_frozen_bits`` and then
        ``_reverse_bits``, in that order.
        """
        vec = self._insert_frozen_bits(vec)
        vec = self._reverse_bits(vec)
        return vec

    def _encode_matrix(self, data):
        """Reference encoder: ``data . G`` reduced mod 2, returned as ints."""
        data = np.dot(data, self.G) % 2
        data = data.astype(dtype=int)
        return data

    def _encode_efficient(self, vec):
        """Encode ``vec`` with ``log2(N)`` butterfly stages over GF(2).

        Intended to give the same result as ``_encode_matrix``; the module's
        ``compare_results`` cross-checks the two.

        The input is copied first, so the caller's array is not modified. The
        dtype of the input is preserved.

        Assumes ``self.N`` is a power of two; the ``reshape`` below fails
        otherwise. Presumably PolarCommon guarantees this — TODO confirm.
        """
        # Work on a copy. Encoding in place used to overwrite the caller's
        # array, which e.g. turned the rate-1 round-trip check into a
        # comparison of an array with itself.
        vec = np.array(vec)
        n_stages = int(np.log2(self.N))
        pos = np.arange(self.N, dtype=int)
        for i in range(n_stages):
            # Stage i: split the positions into 2**(i + 1) equal groups.
            # Each even-numbered group is XORed with the group after it.
            splitted = np.reshape(pos, (2 ** (i + 1), -1))
            upper_branch = splitted[0::2].flatten()
            lower_branch = splitted[1::2].flatten()
            vec[upper_branch] = (vec[upper_branch] + vec[lower_branch]) % 2
        return vec

    def encode(self, data, is_packed=False):
        """Encode ``self.K`` information bits into a polar codeword.

        :param data: ``self.K`` bits (values 0 or 1). If ``is_packed`` is
            True, the same bits packed into uint8 bytes, as produced by
            ``np.packbits``. Because ``np.unpackbits`` yields 8 bits per byte,
            ``self.K`` must then be a multiple of 8.
        :param is_packed: unpack ``data`` before encoding, and pack the
            codeword with ``np.packbits`` before returning it.
        :returns: the encoded vector, packed if ``is_packed`` is True.
        :raises ValueError: if the number of bits is not ``self.K``, or if a
            value lies outside [0, 1].
        """
        if is_packed:
            # Unpack before the length check: K counts bits, and a packed
            # input only holds K / 8 bytes.
            data = np.unpackbits(data)
        if not len(data) == self.K:
            raise ValueError("len(data)={0} is not equal to k={1}!".format(len(data), self.K))
        # np.max / np.min raise on empty input, so only check non-empty data.
        if len(data) > 0 and (np.max(data) > 1 or np.min(data) < 0):
            raise ValueError("can only encode bits!")
        data = self._prepare_input_data(data)
        data = self._encode_efficient(data)
        if is_packed:
            data = np.packbits(data)
        return data


def compare_results(encoder, ntests, k):
    """Cross-check the matrix encoder against the butterfly encoder.

    :param encoder: object providing ``_prepare_input_data``,
        ``_encode_matrix`` and ``_encode_efficient`` (e.g. a PolarEncoder).
    :param ntests: number of random information words to try.
    :param k: number of information bits per word.
    :returns: True if both encoders agree on every word, False at the first
        mismatch (including a shape mismatch).
    """
    for _ in range(ntests):
        bits = np.random.randint(2, size=k)
        preped = encoder._prepare_input_data(bits)
        # Give each implementation its own copy, so one that works in place
        # cannot change the input the other one sees.
        menc = encoder._encode_matrix(np.array(preped))
        fenc = encoder._encode_efficient(np.array(preped))
        if not np.array_equal(menc, fenc):
            return False
    return True


def test_pseudo_rate_1_encoder(encoder, ntests, k):
    for n in range(ntests):
        bits = np.random.randint(2, size=k)
        u = encoder._prepare_input_data(bits)
        fenc = encoder._encode_efficient(u)
        u_hat = encoder._encode_efficient(fenc)
        if not (u_hat == u).all():
            print('rate-1 encoder/decoder failed')
            print u
            print u_hat
            return False
    return True


def test_encoder_impls():
    """Run both encoder self-checks on a fixed 16-bit code with 8 info bits.

    Prints whether the matrix and butterfly encoders agree, then whether the
    pseudo rate-1 encode/decode chain reproduces its input.
    """
    print('Compare encoder implementations, matrix vs. efficient')
    n_runs, block_len, info_len = 1000, 16, 8
    frozen_values = np.zeros(block_len - info_len)
    # frozenbitposition8 = np.array((0, 1, 2, 4), dtype=int)  # keep it!
    frozen_positions = np.array((0, 1, 2, 3, 4, 5, 8, 9), dtype=int)
    enc = PolarEncoder(block_len, info_len, frozen_positions, frozen_values)

    # Format explicitly so the output is identical under Python 2 and 3.
    matches = compare_results(enc, n_runs, info_len)
    print('{0} {1}'.format('result:', matches))

    print('Test rate-1 encoder/decoder chain results')
    round_trip_ok = test_pseudo_rate_1_encoder(enc, n_runs, info_len)
    print('{0} {1}'.format('Test rate-1 encoder/decoder:', round_trip_ok))


def main():
    """Entry point when this module is executed directly.

    Runs ``test_encoder_impls``, which prints its own results.
    """
    return test_encoder_impls()


if __name__ == '__main__':
    # Run the self-tests only when executed as a script, not on import.
    main()