import math as m
import numpy as np

def zero_pad(x, nzeros):
    """Return a copy of ``x`` with ``nzeros`` zeros appended.

    ``np.pad`` is called with ``pad_width=(0, nzeros)``, so for a 1-D array
    the zeros go at the end. The dtype of ``x`` is preserved and ``x`` itself
    is never modified. A negative ``nzeros`` makes ``np.pad`` raise.
    """
    trailing = (0, nzeros)
    source = x.copy()
    return np.pad(source, trailing, mode='constant', constant_values=0)

def cconv_dft(signal, coefs):
    """Circular convolution of ``signal`` with ``coefs`` computed via the DFT.

    Computes ``y[n] = sum_i signal[(n - i) mod N] * coefs[i]`` with
    ``N = signal.size`` as ``ifft(fft(signal) * fft(h))``. Here ``h`` is
    ``coefs`` folded onto length N: tap ``i`` is added to position ``i mod N``.

    Folding (rather than plain zero padding) lets ``coefs`` be longer than
    ``signal``, giving the same result as ``cconv_direct``. Plain zero padding
    used to fail with a negative pad width in that case.

    Returns the real part of the result as a length-N array; an empty
    ``signal`` yields an empty array.
    """
    x = np.asarray(signal)
    h = np.asarray(coefs)
    N = x.size
    if N == 0:
        # np.fft.fft rejects zero-length input; the convolution is empty.
        return np.zeros(0)
    # Pad h to a multiple of N, then sum the length-N rows: this both
    # zero-pads short kernels and wraps long ones around the circle.
    h = np.pad(h, (0, (-h.size) % N), 'constant').reshape(-1, N).sum(axis=0)
    Y = np.fft.fft(x) * np.fft.fft(h)
    return np.real(np.fft.ifft(Y))

def cconv_direct(x, h):
    """Circular convolution evaluated directly from its defining double sum.

    Computes ``y[n] = sum_i x[(n - i) mod x.size] * h[i]`` for every
    ``n in range(x.size)``. ``h`` may be longer than ``x``; its taps then
    wrap around the signal more than once. Returns a float64 array of
    length ``x.size``.
    """
    length = x.size
    out = np.zeros(length)
    for n in range(length):
        for i, coef in enumerate(h):
            # Explicit modulo keeps the index in [0, length) however far
            # i exceeds n.
            out[n] += x[(n - i) % length] * coef
    return out

def cconv_compressed(x, h):
    """Circular convolution written as a single comprehension; returns a list.

    Computes ``y[n] = sum_i x[(n - i) mod x.size] * h[i]`` for every
    ``n in range(x.size)``. The explicit modulo (instead of relying on
    Python's negative indexing, which only wraps once) keeps this correct
    when ``h`` is longer than ``x``, matching ``cconv_direct``.
    """
    n_x = x.size
    return [sum(x[(n - i) % n_x] * h[i] for i in range(h.size)) for n in range(n_x)]

def fold(x, N):
    """Wrap ``x`` onto length ``N`` by summing it in consecutive N-sized chunks.

    Returns ``y`` with ``y[k] = sum_j x[k + j*N]``. The last, possibly
    shorter, chunk is zero-padded, and inputs shorter than N are simply
    zero-padded. The result is always a new array; ``x`` is never modified.

    Raises ValueError if ``N`` is not positive (the old loop never
    terminated for ``N == 0``).
    """
    if N <= 0:
        raise ValueError("fold: N must be a positive integer")
    head = x[0:N]
    # np.pad always returns a fresh array, so the in-place adds below can
    # never write through to the caller's x.
    y = np.pad(head, (0, N - head.size), 'constant', constant_values=0)
    for j in range(N, x.size, N):
        chunk = x[j:j + N]
        # Pad to exactly N: the chunk holds chunk.size samples.
        y += np.pad(chunk, (0, N - chunk.size), 'constant', constant_values=0)
    return y

def cconv_mqz(x,y):
    """Circular convolution of ``x`` with ``y`` via linear convolution plus wrapping.

    The full linear convolution ``w = np.convolve(x, y)`` (length
    ``x.size + y.size - 1``) is wrapped onto length ``x.size``: sample
    ``w[k]`` is added to output position ``k % x.size``. This equals the
    circular convolution for any ``y`` length, including ``y`` longer than
    ``x`` (the previous pad width went negative there).

    Returns an array of length ``x.size`` with ``w``'s dtype.
    """
    bsize = x.size
    w = np.convolve(x, y)
    z = np.zeros(bsize, dtype=w.dtype)
    # Add every bsize-long slice of w (the last may be shorter) onto z.
    for start in range(0, w.size, bsize):
        chunk = w[start:start + bsize]
        z[:chunk.size] += chunk
    return z

def cconv(x,h):
    """Circular convolution used by dwt/idwt; currently delegates to cconv_direct.

    Kept as a single dispatch point so the convolution implementation used
    by the wavelet transforms can be changed in one place.
    """
    impl = cconv_direct
    return impl(x, h)


def dwt(x, L=1, t="Haar"):
    """Forward L-stage 1-D discrete wavelet transform using circular convolution.

    At each stage the current band ``X[0:N]`` is circularly convolved with
    the low-pass filter ``l`` and the high-pass filter ``h``. Each result is
    rolled by a filter-specific offset (``il`` / ``ih``) and the even-indexed
    samples are kept. The low-pass half goes to ``X[0:N//2]`` and the
    high-pass half to ``X[N//2:N]``. N is then halved, so later stages only
    re-transform the low-pass band.

    Parameters
    ----------
    x : numpy.ndarray
        1-D signal whose size must be divisible by ``2**L``. Not modified.
    L : int
        Number of stages. For "Le Gall 5/3" it may be reduced (see below).
    t : str
        Filter bank: "Haar", "Le Gall 5/3" or "Daubechies 4-tap".

    Returns
    -------
    numpy.ndarray or None
        Transform coefficients, same size as ``x``. Integer/bool input is
        promoted to float64 so coefficients are not truncated. Returns
        ``None`` (after printing a message) if the size is not divisible by
        ``2**L`` or ``t`` is unknown.
    """
    # Test compatibility of x and L
    if x.size % (2 ** L) != 0:
        print("{0:d}-stage dwt: signal length must be divisible by {1:d}".format(L, 2 **L))
        return
    # select filter bank
    if t == "Haar":
        l = m.sqrt(2)/2 * np.array([ 1, 1 ])
        il = 0
        h = m.sqrt(2)/2 * np.array([ 1, -1 ])
        ih = 0
    elif t == "Le Gall 5/3":
        if L > np.log2(x.size)-2:
            # Clamp the stage count. L must remain an int (range(L) rejects
            # floats) and must match idwt's formula so the inverse undoes
            # exactly the stages applied here. Sizes <= 2 cannot be iterated.
            # NOTE(review): a former warning message gave the limit as
            # floor(log2(size) - 2), not floor(log2(size - 2)); confirm the
            # intended bound and change dwt and idwt together.
            L = int(np.floor(np.log2(x.size - 2))) if x.size > 2 else 0
        l = np.array([ -1/8, 1/4, 3/4, 1/4, -1/8 ])
        il = -2
        h = np.array([ -1/2, 1, -1/2 ])
        ih = 0
    elif t == "Daubechies 4-tap":
        c0 = 1 + m.sqrt(3)
        c1 = 3 + m.sqrt(3)
        c2 = 3 - m.sqrt(3)
        c3 = 1 - m.sqrt(3)
        l = (1 / (4*m.sqrt(2)) ) * np.array([ c0, c1, c2, c3 ])
        il = 0
        h = (1 / (4*m.sqrt(2)) ) * np.array([ c3, -c2, c1, -c0 ])
        ih = 0
    else:
        print("dwt: t must be 'Haar', 'Le Gall 5/3' or 'Daubechies 4-tap'")
        return

    # Work on a float copy: writing filter outputs into an integer array
    # would silently truncate the coefficients.
    if np.issubdtype(x.dtype, np.inexact):
        X = x.copy()
    else:
        X = x.astype(np.float64)
    # This is the actual dwt
    N = X.size
    for _ in range(L):
        y = X[0:N].copy()
        xl = cconv(y,l)
        # Keep even samples of the rolled low-pass output -> first half.
        X[0:N//2] = np.roll(xl,il)[0:N:2]
        xh = cconv(y,h)
        # Keep even samples of the rolled high-pass output -> second half.
        X[N//2:N] = np.roll(xh,ih)[0:N:2]
        N //= 2

    return X

def idwt(X, L=1, t="Haar"):
    """Inverse L-stage 1-D discrete wavelet transform (counterpart of dwt).

    ``N`` starts at ``X.size // 2**L``. At each stage ``x[0:N]`` (low-pass)
    and ``x[N:2N]`` (high-pass) are upsampled by 2 by placing them on the
    even indices of zero arrays. Each is circularly convolved with its
    synthesis filter and rolled by ``il`` / ``ih``. The two are summed,
    rolled by ``ix`` and written back to ``x[0:2N]``, and N is then doubled.

    Parameters
    ----------
    X : numpy.ndarray
        1-D coefficient array whose size must be divisible by ``2**L``.
        Not modified.
    L : int
        Number of stages. For "Le Gall 5/3" it may be reduced, using the same
        rule as dwt.
    t : str
        Filter bank: "Haar", "Le Gall 5/3" or "Daubechies 4-tap".

    Returns
    -------
    numpy.ndarray or None
        Reconstructed signal, same size as ``X``. Integer/bool input is
        promoted to float64 so samples are not truncated. Returns ``None``
        (after printing a message) on an invalid size/``L`` or unknown ``t``.
    """
    # Test compatibility of x and L
    if X.size % (2 ** L) != 0:
        print("{0:d}-stage idwt: signal length must be divisible by {1:d}".format(L, 2 **L))
        return
    # select filter bank
    if t == "Haar":
        l = m.sqrt(2)/2 * np.array([ 1, 1 ])
        il = -1
        h = m.sqrt(2)/2 * np.array([ 1, -1 ])
        ih = -1
        ix = 0
    elif t == "Le Gall 5/3":
        if L > np.log2(X.size)-2:
            # Clamp the stage count with the same formula dwt uses so the
            # inverse matches the forward transform. Sizes <= 2 would
            # otherwise hit log2 of a non-positive number.
            # NOTE(review): a former warning message gave the limit as
            # floor(log2(size) - 2); confirm the intended bound together
            # with dwt.
            L = int(np.floor(np.log2(X.size - 2))) if X.size > 2 else 0
        l = np.array([ 1/2, 1, 1/2 ])
        il = 0
        h = np.array([ -1/8, -1/4, 3/4, -1/4, -1/8 ])
        ih = -2
        ix = -1
    elif t == "Daubechies 4-tap":
        c0 = 1 + m.sqrt(3)
        c1 = 3 + m.sqrt(3)
        c2 = 3 - m.sqrt(3)
        c3 = 1 - m.sqrt(3)
        l = (1 / (4*m.sqrt(2)) ) * np.array([ c0, c1, c2, c3 ])
        il = -3
        h = (1 / (4*m.sqrt(2)) ) * np.array([ c3, -c2, c1, -c0 ])
        ih = -3
        ix = 0
    else:
        print("idwt: t must be 'Haar', 'Le Gall 5/3' or 'Daubechies 4-tap'")
        return

    # Work on a float copy: writing reconstructed samples into an integer
    # array would silently truncate them.
    if np.issubdtype(X.dtype, np.inexact):
        x = X.copy()
    else:
        x = X.astype(np.float64)
    # This is the actual idwt
    N = x.size//(2**L)
    for _ in range(L):
        Y = x[0:2*N]
        # Upsample by 2: coefficients on even indices, zeros on odd ones.
        UXl = np.zeros(2*N)
        UXh = np.zeros(2*N)
        UXl[0:2*N:2] = Y[0:N]
        UXh[0:2*N:2] = Y[N:2*N]
        Xl = np.roll(cconv(UXl,l),il)
        Xh = np.roll(cconv(UXh,h),ih)
        x[0:2*N] = np.roll(Xl+Xh,ix)
        N *= 2

    return x

def dwt2(x, L=1, t="Haar"):
    """L-level separable 2-D discrete wavelet transform of an M x N array.

    Each level applies a one-stage ``dwt`` to every column of the current
    top-left ``M x N`` sub-array and then to every row. M and N are then
    halved, so the next level only transforms the low-low quadrant.

    Parameters
    ----------
    x : array_like
        2-D input (e.g. an image). Not modified.
    L : int
        Number of decomposition levels.
    t : str
        Filter bank passed through to ``dwt``.

    Returns
    -------
    numpy.ndarray
        New float64 array of the same shape holding the coefficients.
    """
    # np.array(..., dtype=...) always returns a fresh copy, so the caller's
    # array is never transformed in place. np.float64(x) may hand back x
    # itself when it is already a float64 array.
    X = np.array(x, dtype=np.float64)
    # image with dimension MxN
    M, N = X.shape
    for _ in range(L):
        for j in range(N):
            X[0:M, j] = dwt( X[0:M, j], 1, t)
        for j in range(M):
            X[j, 0:N] = dwt( X[j, 0:N], 1, t)
        M //= 2
        N //= 2
    return X

def idwt2(X, L=1, t="Haar"):
    """Inverse of ``dwt2``: L-level separable 2-D inverse wavelet transform.

    Starts with the top-left sub-array of size
    ``rows // 2**(L-1)`` x ``cols // 2**(L-1)``. Each level applies a
    one-stage ``idwt`` to every row and then every column (the reverse of
    dwt2's column-then-row order), and then doubles the sub-array size.

    Parameters
    ----------
    X : numpy.ndarray
        2-D coefficient array. Not modified.
    L : int
        Number of levels to invert.
    t : str
        Filter bank passed through to ``idwt``.

    Returns
    -------
    numpy.ndarray
        Reconstructed array of the same shape. Integer/bool input is promoted
        to float64, as in ``dwt2``, so reconstructed values are not truncated.
    """
    if np.issubdtype(X.dtype, np.inexact):
        x = X.copy()
    else:
        x = X.astype(np.float64)
    # image with dimension MxN
    my, mx = x.shape
    M = my // (2**(L-1))
    N = mx // (2**(L-1))
    for _ in range(L):
        for j in range(M):
            x[j, 0:N] = idwt( x[j, 0:N], 1, t)
        for j in range(N):
            x[0:M, j] = idwt( x[0:M, j], 1, t)
        M *= 2
        N *= 2
    return x

