#!/usr/bin/env python
#
# Copyright 2015 Free Software Foundation, Inc.
#
# This file is part of GNU Radio
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
#


import string
import sys

import numpy as np
from numpy.linalg import inv, det
from numpy.random import shuffle, randint


# 0 gives no debug output, 1 gives a little, 2 gives a lot
# verbose = 1 #######################################################

def read_alist_file(filename):
    """
    This function reads in an alist file and creates the
    corresponding parity check matrix H. The format of alist
    files is described at:
    http://www.inference.phy.cam.ac.uk/mackay/codes/alist.html
    """

    with open(filename, 'r') as myfile:
        data = myfile.readlines()
        numCols, numRows = parse_alist_header(data[0])
        H = zeros((numRows, numCols))
        # The locations of 1s starts in the 5th line of the file
        for lineNumber in np.arange(4, 4 + numCols):
            indices = data[lineNumber].split()
            for index in indices:
                H[int(index) - 1, lineNumber - 4] = 1
                # The subsequent lines in the file list the indices for where
                # the 1s are in the rows, but this is redundant
                # information.

        return H


def parse_alist_header(header):
    size = header.split()
    return int(size[0]), int(size[1])


def write_alist_file(filename, H, verbose=0):
    """
    This function writes an alist file for the parity check
    matrix. The format of alist files is described at:
    http://www.inference.phy.cam.ac.uk/mackay/codes/alist.html
    """
    with open(filename, 'w') as myfile:

        numRows = H.shape[0]
        numCols = H.shape[1]

        tempstring = repr(numCols) + ' ' + repr(numRows) + '\n'
        myfile.write(tempstring)

        tempstring1 = ''
        tempstring2 = ''
        maxRowWeight = 0
        for rowNum in np.arange(numRows):
            nonzeros = array(H[rowNum, :].nonzero())
            rowWeight = nonzeros.shape[1]
            if rowWeight > maxRowWeight:
                maxRowWeight = rowWeight
            tempstring1 = tempstring1 + repr(rowWeight) + ' '
            for tempArray in nonzeros:
                for index in tempArray:
                    tempstring2 = tempstring2 + repr(index + 1) + ' '
            tempstring2 = tempstring2 + '\n'
        tempstring1 = tempstring1 + '\n'

        tempstring3 = ''
        tempstring4 = ''
        maxColWeight = 0
        for colNum in np.arange(numCols):
            nonzeros = array(H[:, colNum].nonzero())
            colWeight = nonzeros.shape[1]
            if colWeight > maxColWeight:
                maxColWeight = colWeight
            tempstring3 = tempstring3 + repr(colWeight) + ' '
            for tempArray in nonzeros:
                for index in tempArray:
                    tempstring4 = tempstring4 + repr(index + 1) + ' '
            tempstring4 = tempstring4 + '\n'
        tempstring3 = tempstring3 + '\n'

        tempstring = repr(maxColWeight) + ' ' + repr(maxRowWeight) + '\n'
        # write out max column and row weights
        myfile.write(tempstring)
        # write out all of the column weights
        myfile.write(tempstring3)
        # write out all of the row weights
        myfile.write(tempstring1)
        # write out the nonzero indices for each column
        myfile.write(tempstring4)
        # write out the nonzero indices for each row
        myfile.write(tempstring2)


class LDPC_matrix(object):
    """ Class for a LDPC parity check matrix """

    def __init__(self, alist_filename=None,
                 n_p_q=None,
                 H_matrix=None):
        if (alist_filename != None):
            self.H = read_alist_file(alist_filename)
        elif (n_p_q != None):
            self.H = self.regular_LDPC_code_contructor(n_p_q)
        elif (H_matrix != None):
            self.H = H_matrix
        else:
            print('Error: provide either an alist filename, ', end='')
            print('parameters for constructing regular LDPC parity, ', end='')
            print('check matrix, or a numpy array.')

        self.rank = linalg.matrix_rank(self.H)
        self.numRows = self.H.shape[0]
        self.n = self.H.shape[1]
        self.k = self.n - self.numRows

    def regular_LDPC_code_contructor(self, n_p_q):
        """
        This function constructs a LDPC parity check matrix
        H. The algorithm follows Gallager's approach where we create
        p submatrices and stack them together. Reference: Turbo
        Coding for Satellite and Wireless Communications, section
        9,3.

        Note: the matrices computed from this algorithm will never
        have full rank. (Reference Gallager's Dissertation.) They
        will have rank = (number of rows - p + 1). To convert it
        to full rank, use the function get_full_rank_H_matrix
        """

        n = n_p_q[0]  # codeword length
        p = n_p_q[1]  # column weight
        q = n_p_q[2]  # row weight
        # TODO: There should probably be other guidelines for n/p/q,
        # but I have not found any specifics in the literature....

        # For this algorithm, n/p must be an integer, because the
        # number of rows in each submatrix must be a whole number.
        ratioTest = (n * 1.0) / q
        if ratioTest % 1 != 0:
            print('\nError in regular_LDPC_code_contructor: The ', end='')
            print('ratio of inputs n/q must be a whole number.\n')
            return

        # First submatrix first:
        m = (n * p) / q  # number of rows in H matrix
        submatrix1 = zeros((m / p, n))
        for row in np.arange(m / p):
            range1 = row * q
            range2 = (row + 1) * q
            submatrix1[row, range1:range2] = 1
            H = submatrix1

        # Create the other submatrices and vertically stack them on.
        submatrixNum = 2
        newColumnOrder = np.arange(n)
        while submatrixNum <= p:
            submatrix = zeros((m / p, n))
            shuffle(newColumnOrder)

            for columnNum in np.arange(n):
                submatrix[:, columnNum] = \
                    submatrix1[:, newColumnOrder[columnNum]]

            H = vstack((H, submatrix))
            submatrixNum = submatrixNum + 1

        # Double check the row weight and column weights.
        size = H.shape
        rows = size[0]
        cols = size[1]

        # Check the row weights.
        for rowNum in np.arange(rows):
            nonzeros = array(H[rowNum, :].nonzero())
            if nonzeros.shape[1] != q:
                print('Row', rowNum, 'has incorrect weight!')
                return

        # Check the column weights
        for columnNum in np.arange(cols):
            nonzeros = array(H[:, columnNum].nonzero())
            if nonzeros.shape[1] != p:
                print('Row', columnNum, 'has incorrect weight!')
                return

        return H


def greedy_upper_triangulation(H, verbose=0):
    """
    This function performs row/column permutations to bring
    H into approximate upper triangular form via greedy
    upper triangulation method outlined in Modern Coding
    Theory Appendix 1, Section A.2
    """
    H_t = H.copy()

    # Per email from Dr. Urbanke, author of this textbook, this
    # algorithm requires H to be full rank
    if linalg.matrix_rank(H_t) != H_t.shape[0]:
        print('Rank of H:', linalg.matrix_rank(tempArray))
        print('H has', H_t.shape[0], 'rows')
        print('Error: H must be full rank.')
        return

    size = H_t.shape
    n = size[1]
    k = n - size[0]
    g = t = 0

    while t != (n - k - g):
        H_residual = H_t[t:n - k - g, t:n]
        size = H_residual.shape
        numRows = size[0]
        numCols = size[1]

        minResidualDegrees = zeros((1, numCols), dtype=int)

        for colNum in np.arange(numCols):
            nonZeroElements = array(H_residual[:, colNum].nonzero())
            minResidualDegrees[0, colNum] = nonZeroElements.shape[1]

        # Find the minimum nonzero residual degree
        nonZeroElementIndices = minResidualDegrees.nonzero()
        nonZeroElements = minResidualDegrees[nonZeroElementIndices[0],
                                             nonZeroElementIndices[1]]
        minimumResidualDegree = nonZeroElements.min()

        # Get indices of all of the columns in H_t that have degree
        # equal to the min positive residual degree, then pick a
        # random column c.
        indices = (minResidualDegrees == minimumResidualDegree) \
            .nonzero()[1]
        indices = indices + t
        if indices.shape[0] == 1:
            columnC = indices[0]
        else:
            randomIndex = randint(0, indices.shape[0], (1, 1))[0][0]
            columnC = indices[randomIndex]

        Htemp = H_t.copy()

        if minimumResidualDegree == 1:
            # This is the 'extend' case
            rowThatContainsNonZero = H_residual[:, columnC - t].nonzero()[0][0]

            # Swap column c with column t. (Book says t+1 but we
            # index from 0, not 1.)
            Htemp[:, columnC] = H_t[:, t]
            Htemp[:, t] = H_t[:, columnC]
            H_t = Htemp.copy()
            Htemp = H_t.copy()
            # Swap row r with row t. (Book says t+1 but we index from
            # 0, not 1.)
            Htemp[rowThatContainsNonZero + t, :] = H_t[t, :]
            Htemp[t, :] = H_t[rowThatContainsNonZero + t, :]
            H_t = Htemp.copy()
            Htemp = H_t.copy()
        else:
            # This is the 'choose' case.
            rowsThatContainNonZeros = H_residual[:, columnC - t] \
                .nonzero()[0]

            # Swap column c with column t. (Book says t+1 but we
            # index from 0, not 1.)
            Htemp[:, columnC] = H_t[:, t]
            Htemp[:, t] = H_t[:, columnC]
            H_t = Htemp.copy()
            Htemp = H_t.copy()

            # Swap row r1 with row t
            r1 = rowsThatContainNonZeros[0]
            Htemp[r1 + t, :] = H_t[t, :]
            Htemp[t, :] = H_t[r1 + t, :]
            numRowsLeft = rowsThatContainNonZeros.shape[0] - 1
            H_t = Htemp.copy()
            Htemp = H_t.copy()

            # Move the other rows that contain nonZero entries to the
            # bottom of the matrix. We can't just swap them,
            # otherwise we will be pulling up rows that we pushed
            # down before. So, use a rotation method.
            for index in np.arange(1, numRowsLeft + 1):
                rowInH_residual = rowsThatContainNonZeros[index]
                rowInH_t = rowInH_residual + t - index + 1
                m = n - k
                # Move the row with the nonzero element to the
                # bottom; don't update H_t.
                Htemp[m - 1, :] = H_t[rowInH_t, :]
                # Now rotate the bottom rows up.
                sub_index = 1
                while sub_index < (m - rowInH_t):
                    Htemp[m - sub_index - 1, :] = H_t[m - sub_index, :]
                    sub_index = sub_index + 1
                H_t = Htemp.copy()
                Htemp = H_t.copy()

            # Save temp H as new H_t.
            H_t = Htemp.copy()
            Htemp = H_t.copy()
            g = g + (minimumResidualDegree - 1)

        t = t + 1

    if g == 0:
        if verbose:
            print('Error: gap is 0.')
        return

    # We need to ensure phi is nonsingular.
    T = H_t[0:t, 0:t]
    E = H_t[t:t + g, 0:t]
    A = H_t[0:t, t:t + g]
    C = H_t[t:t + g, t:t + g]
    D = H_t[t:t + g, t + g:n]

    invTmod2array = inv_mod2(T)
    temp1 = dot(E, invTmod2array) % 2
    temp2 = dot(temp1, A) % 2
    phi = (C - temp2) % 2
    if phi.any():
        try:
            # Try to take the inverse of phi.
            invPhi = inv_mod2(phi)
        except linalg.linalg.LinAlgError:
            # Phi is singular
            if verbose > 1:
                print('Initial phi is singular')
        else:
            # Phi is nonsingular, so we need to use this version of H.
            if verbose > 1:
                print('Initial phi is nonsingular')
            return [H_t, g, t]
    else:
        if verbose:
            print('Initial phi is all zeros:\n', phi)

    # If the C and D submatrices are all zeros, there is no point in
    # shuffling them around in an attempt to find a good phi.
    if not (C.any() or D.any()):
        if verbose:
            print('C and D are all zeros. There is no hope in', )
            print('finding a nonsingular phi matrix. ')
        return

    # We can't look at every row/column permutation possibility
    # because there would be (n-t)! column shuffles and g! row
    # shuffles. g has gotten up to 12 in tests, so 12! would still
    # take quite some time. Instead, we will just pick an arbitrary
    # number of max iterations to perform, then break.
    maxIterations = 300
    iterationCount = 0
    columnsToShuffle = np.arange(t, n)
    rowsToShuffle = np.arange(t, t + g)

    while iterationCount < maxIterations:
        if verbose > 1:
            print('iterationCount:', iterationCount)
        tempH = H_t.copy()

        shuffle(columnsToShuffle)
        shuffle(rowsToShuffle)
        index = 0
        for newDestinationColumnNumber in np.arange(t, n):
            oldColumnNumber = columnsToShuffle[index]
            tempH[:, newDestinationColumnNumber] = \
                H_t[:, oldColumnNumber]
            index += 1

        tempH2 = tempH.copy()
        index = 0
        for newDesinationRowNumber in np.arange(t, t + g):
            oldRowNumber = rowsToShuffle[index]
            tempH[newDesinationRowNumber, :] = tempH2[oldRowNumber, :]
            index += 1

        # Now test this new H matrix.
        H_t = tempH.copy()
        T = H_t[0:t, 0:t]
        E = H_t[t:t + g, 0:t]
        A = H_t[0:t, t:t + g]
        C = H_t[t:t + g, t:t + g]
        invTmod2array = inv_mod2(T)
        temp1 = dot(E, invTmod2array) % 2
        temp2 = dot(temp1, A) % 2
        phi = (C - temp2) % 2
        if phi.any():
            try:
                # Try to take the inverse of phi.
                invPhi = inv_mod2(phi)
            except linalg.linalg.LinAlgError:
                # Phi is singular
                if verbose > 1:
                    print('Phi is still singular')
            else:
                # Phi is nonsingular, so we're done.
                if verbose:
                    print('Found a nonsingular phi on', )
                    print('iterationCount = ', iterationCount)
                return [H_t, g, t]
        else:
            if verbose > 1:
                print('phi is all zeros')

        iterationCount += 1

    # If we've reached this point, then we haven't found a
    # version of H that has a nonsingular phi.
    if verbose:
        print('--- Error: nonsingular phi matrix not found.')


def inv_mod2(squareMatrix, verbose=0):
    """
    Calculates the mod 2 inverse of a matrix.
    """
    A = squareMatrix.copy()
    t = A.shape[0]

    # Special case for one element array [1]
    if A.size == 1 and A[0] == 1:
        return array([1])

    Ainverse = inv(A)
    B = det(A) * Ainverse
    C = B % 2

    # Encountered lots of rounding errors with this function.
    # Previously tried floor, C.astype(int), and casting with (int)
    # and none of that works correctly, so doing it the tedious way.

    test = dot(A, C) % 2
    tempTest = zeros_like(test)
    for colNum in np.arange(test.shape[1]):
        for rowNum in np.arange(test.shape[0]):
            value = test[rowNum, colNum]
            if (abs(1 - value)) < 0.01:
                # this is a 1
                tempTest[rowNum, colNum] = 1
            elif (abs(2 - value)) < 0.01:
                # there shouldn't be any 2s after B % 2, but I'm
                # seeing them!
                tempTest[rowNum, colNum] = 0
            elif (abs(0 - value)) < 0.01:
                # this is a 0
                tempTest[rowNum, colNum] = 0
            else:
                if verbose > 1:
                    print('In inv_mod2. Rounding error on this', )
                    print('value? Mod 2 has already been done.', )
                    print('value:', value)

    test = tempTest.copy()

    if (test - eye(t, t) % 2).any():
        if verbose:
            print('Error in inv_mod2: did not find inverse.')
            # TODO is this the most appropriate error to raise?
        raise linalg.linalg.LinAlgError
    else:
        return C


def swap_columns(a, b, arrayIn):
    """
    Swaps two columns in a matrix.
    """
    arrayOut = arrayIn.copy()
    arrayOut[:, a] = arrayIn[:, b]
    arrayOut[:, b] = arrayIn[:, a]
    return arrayOut


def move_row_to_bottom(i, arrayIn):
    """"
    Moves a specified row (just one) to the bottom of the matrix,
    then rotates the rows at the bottom up.

    For example, if we had a matrix with 5 rows, and we wanted to
    push row 2 to the bottom, then the resulting row order would be:
    1,3,4,5,2
    """
    arrayOut = arrayIn.copy()
    numRows = arrayOut.shape[0]
    # Push the specified row to the bottom.
    arrayOut[numRows - 1] = arrayIn[i, :]
    # Now rotate the bottom rows up.
    index = 2
    while (numRows - index) >= i:
        arrayOut[numRows - index, :] = arrayIn[numRows - index + 1]
        index = index + 1
    return arrayOut


def get_full_rank_H_matrix(H, verbose=False):
    """
    This function accepts a parity check matrix H and, if it is not
    already full rank, will determine which rows are dependent and
    remove them. The updated matrix will be returned.
    """
    tempArray = H.copy()
    if linalg.matrix_rank(tempArray) == tempArray.shape[0]:
        if verbose:
            print('Returning H; it is already full rank.')
        return tempArray

    numRows = tempArray.shape[0]
    numColumns = tempArray.shape[1]
    limit = numRows
    rank = 0
    i = 0

    # Create an array to save the column permutations.
    columnOrder = np.arange(numColumns).reshape(1, numColumns)

    # Create an array to save the row permutations. We just need
    # this to know which dependent rows to delete.
    rowOrder = np.arange(numRows).reshape(numRows, 1)

    while i < limit:
        if verbose:
            print('In get_full_rank_H_matrix; i:', i)
            # Flag indicating that the row contains a non-zero entry
        found = False
        for j in np.arange(i, numColumns):
            if tempArray[i, j] == 1:
                # Encountered a non-zero entry at (i, j)
                found = True
                # Increment rank by 1
                rank = rank + 1
                # Make the entry at (i,i) be 1
                tempArray = swap_columns(j, i, tempArray)
                # Keep track of the column swapping
                columnOrder = swap_columns(j, i, columnOrder)
                break
        if found == True:
            for k in np.arange(0, numRows):
                if k == i:
                    continue
                # Checking for 1's
                if tempArray[k, i] == 1:
                    # Add row i to row k
                    tempArray[k, :] = tempArray[k, :] + tempArray[i, :]
                    # Addition is mod2
                    tempArray = tempArray.copy() % 2
                    # All the entries above & below (i, i) are now 0
            i = i + 1
        if found == False:
            # Push the row of 0s to the bottom, and move the bottom
            # rows up (sort of a rotation thing).
            tempArray = move_row_to_bottom(i, tempArray)
            # Decrease limit since we just found a row of 0s
            limit -= 1
            # Keep track of row swapping
            rowOrder = move_row_to_bottom(i, rowOrder)

    # Don't need the dependent rows
    finalRowOrder = rowOrder[0:i]

    # Reorder H, per the permutations taken above .
    # First, put rows in order, omitting the dependent rows.
    newNumberOfRowsForH = finalRowOrder.shape[0]
    newH = zeros((newNumberOfRowsForH, numColumns))
    for index in np.arange(newNumberOfRowsForH):
        newH[index, :] = H[finalRowOrder[index], :]

    # Next, put the columns in order.
    tempHarray = newH.copy()
    for index in np.arange(numColumns):
        newH[:, index] = tempHarray[:, columnOrder[0, index]]

    if verbose:
        print('original H.shape:', H.shape)
        print('newH.shape:', newH.shape)

    return newH


def get_best_matrix(H, numIterations=100, verbose=0):
    """
    This function will run the Greedy Upper Triangulation algorithm
    for numIterations times, looking for the lowest possible gap.
    The submatrices returned are those needed for real-time encoding.
    """

    hadFirstJoy = 0
    index = 1
    while index <= numIterations:
        if verbose:
            print('--- In get_best_matrix, iteration:', index)
        index += 1
        try:
            ret = greedy_upper_triangulation(H, verbose)
        except ValueError as e:
            if verbose > 1:
                print('greedy_upper_triangulation error: ', e)
        else:
            if ret:
                [betterH, gap, t] = ret
            else:
                continue

            if not hadFirstJoy:
                hadFirstJoy = 1
                bestGap = gap
                bestH = betterH.copy()
                bestT = t
            elif gap < bestGap:
                bestGap = gap
                bestH = betterH.copy()
                bestT = t

    if hadFirstJoy:
        return [bestH, bestGap]
    else:
        if verbose:
            print('Error: Could not find appropriate H form', )
            print('for encoding.')
        return


def getSystematicGmatrix(GenMatrix):
    """
    This function finds the systematic form of the generator
    matrix GenMatrix. This form is G = [I P] where I is an identity
    matrix and P is the parity submatrix. If the GenMatrix matrix
    provided is not full rank, then dependent rows will be deleted.

    This function does not convert parity check (H) matrices to the
    generator matrix format. Use the function getSystematicGmatrixFromH
    for that purpose.
    """
    tempArray = GenMatrix.copy()
    numRows = tempArray.shape[0]
    numColumns = tempArray.shape[1]
    limit = numRows
    rank = 0
    i = 0
    while i < limit:
        # Flag indicating that the row contains a non-zero entry
        found = False
        for j in np.arange(i, numColumns):
            if tempArray[i, j] == 1:
                # Encountered a non-zero entry at (i, j)
                found = True
                # Increment rank by 1
                rank = rank + 1
                # make the entry at (i,i) be 1
                tempArray = swap_columns(j, i, tempArray)
                break
        if found == True:
            for k in np.arange(0, numRows):
                if k == i:
                    continue
                # Checking for 1's
                if tempArray[k, i] == 1:
                    # add row i to row k
                    tempArray[k, :] = tempArray[k, :] + tempArray[i, :]
                    # Addition is mod2
                    tempArray = tempArray.copy() % 2
                    # All the entries above & below (i, i) are now 0
            i = i + 1
        if found == False:
            # push the row of 0s to the bottom, and move the bottom
            # rows up (sort of a rotation thing)
            tempArray = move_row_to_bottom(i, tempArray)
            # decrease limit since we just found a row of 0s
            limit -= 1
            # the rows below i are the dependent rows, which we discard
    G = tempArray[0:i, :]
    return G


def getSystematicGmatrixFromH(H, verbose=False):
    """
    If given a parity check matrix H, this function returns a
    generator matrix G in the systematic form: G = [I P]
      where:  I is an identity matrix, size k x k
              P is the parity submatrix, size k x (n-k)
    If the H matrix provided is not full rank, then dependent rows
    will be deleted first.
    """
    if verbose:
        print('received H with size: ', H.shape)

    # First, put the H matrix into the form H = [I|m] where:
    #   I is (n-k) x (n-k) identity matrix
    #   m is (n-k) x k
    # This part is just copying the algorithm from getSystematicGmatrix
    tempArray = getSystematicGmatrix(H)

    # Next, swap I and m columns so the matrix takes the forms [m|I].
    n = H.shape[1]
    k = n - H.shape[0]
    I_temp = tempArray[:, 0:(n - k)]
    m = tempArray[:, (n - k):n]
    newH = concatenate((m, I_temp), axis=1)

    # Now the submatrix m is the transpose of the parity submatrix,
    # i.e. H is in the form H = [P'|I]. So G is just [I|P]
    k = m.shape[1]
    G = concatenate((identity(k), m.T), axis=1)
    if verbose:
        print('returning G with size: ', G.shape)
    return G