import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import shannon_entropy
#import cv2
from scipy.fftpack import dct, idct
import os #to get file size

def PSNR(original, compressed, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Parameters
    ----------
    original, compressed : array_like
        Images of the same shape. Both are converted to float64 before the
        subtraction so that unsigned integer inputs (e.g. uint8) cannot wrap
        around.
    max_pixel : float, optional
        Largest possible pixel value (255 for 8-bit images).

    Returns
    -------
    float
        PSNR in decibels, or the sentinel value 10000 when the images are
        identical (zero mean squared error, i.e. infinite PSNR).
    """
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical images: the true PSNR is infinite; keep the historical sentinel.
        return 10000
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def jpeg_compress(I1, q, Q=None):
    """Simulate JPEG coding of one block and return the reconstructed block.

    The block is transformed with an orthonormal 2-D DCT-II, quantized with a
    table scaled for the quality factor ``q``, dequantized, and transformed
    back with the inverse DCT.

    Parameters
    ----------
    I1 : array_like
        Pixel block whose shape matches ``Q`` (8x8 for the default table).
    q : int or float
        Quality factor in ]0, 100]; higher values mean finer quantization.
    Q : array_like, optional
        Base quantization table indexed [vertical frequency, horizontal
        frequency]. Defaults to the JPEG luminance table (ITU-T T.81, Table K.1).

    Returns
    -------
    ndarray
        Reconstructed block as floats (neither rounded nor clipped to 8 bits).

    Raises
    ------
    ValueError
        If ``q`` is not in ]0, 100].
    """
    if not 0 < q <= 100:
        raise ValueError('q must be in ]0, 100], got {!r}'.format(q))

    if Q is None:
        # Quantization matrix (JPEG luminance table)
        Q = np.array([[16,  11,  10,  16,  24,  40,  51,  61],
                      [12,  12,  14,  19,  26,  58,  60,  55],
                      [14,  13,  16,  24,  40,  57,  69,  56],
                      [14,  17,  22,  29,  51,  87,  80,  62],
                      [18,  22,  37,  56,  68, 109, 103,  77],
                      [24,  35,  55,  64,  81, 104, 113,  92],
                      [49,  64,  78,  87, 103, 121, 120, 101],
                      [72,  92,  95,  98, 112, 100, 103,  99]], dtype=float)
    else:
        Q = np.asarray(Q, dtype=float)

    # Quality -> scale factor mapping of the IJG libjpeg reference encoder
    alpha = 50 / q if q < 50 else (100 - q) / 50

    # Effective quantization steps. libjpeg clamps them to at least 1; without
    # the clamp q = 100 gives alpha = 0 and a 0/0 division (NaN output).
    step = np.maximum(alpha * Q, 1.0)

    # Separable 2-D DCT: along columns (axis 0), then along rows (axis 1), so
    # that F[u, v] holds vertical frequency u and horizontal frequency v --
    # the same layout as Q.
    F = dct(dct(I1, axis=0, norm='ortho'), axis=1, norm='ortho')

    # Quantization (round to nearest, ITU-T T.81 A.3.4), then dequantization
    # back to the range of the original coefficients.
    Fq = np.round(F / step)
    Fdq = Fq * step

    # Inverse 2-D DCT back to the pixel domain
    I1_c = idct(idct(Fdq, axis=0, norm='ortho'), axis=1, norm='ortho')

    return I1_c
    
def blockproc(im, block_sz, func, arg):
    """Apply ``func(tile, arg)`` to each non-overlapping tile of a 2-D array, in place.

    Tiles of size ``block_sz = (rows, cols)`` are visited row by row; tiles on
    the bottom/right edges are smaller when the array size is not a multiple
    of the tile size. Each tile is overwritten with ``func``'s result (which
    must broadcast to the tile's shape and is cast to ``im``'s dtype), so
    ``im`` itself is modified and then returned.
    """
    n_rows, n_cols = im.shape
    tile_h, tile_w = block_sz
    for top in range(0, n_rows, tile_h):
        bottom = top + tile_h
        for left in range(0, n_cols, tile_w):
            tile = im[top:bottom, left:left + tile_w]
            tile[...] = func(tile, arg)
    return im

#%% Save as JPEG at maximum quality, to avoid an extra lossy pass
#    (plt.imsave offers no control over the JPEG quality).
# Option 1 -- OpenCV:
#   cv2.imwrite('cameraman.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, 100])

# Option 2 -- PIL (Pillow)
from PIL import Image
image = Image.open('../img/cameraman.tif')
w, h = image.width, image.height
# Flat pixel sequence -> (rows, cols) 8-bit array
img = np.array(image.getdata()).reshape((h, w)).astype(np.uint8)

#%%
# Any modification of img would go here before writing it back.
image.putdata(img.reshape(-1))
image.save('cameraman.jpg', quality=100)

# File size of the reference JPEG, in bytes
file_stat = os.stat('cameraman.jpg')
size_init = file_stat.st_size
print(size_init)

# JPEG-like compression at a single quality factor
q = 80  # quality factor (vary it to explore the rate/distortion trade-off)

# Block-wise processing on 8x8 tiles. astype() returns a new float array, so
# img itself is left untouched even though blockproc works in place.
img_compress_jpg = blockproc(img.astype('float'), (8, 8), jpeg_compress, q)

# Display the initial and compressed images side by side
plt.figure()
plt.subplot(121)
plt.imshow(img, cmap='gray', vmin=0, vmax=255)
plt.title('Image initiale')
plt.subplot(122)
plt.imshow(img_compress_jpg, cmap='gray', vmin=0, vmax=255)
plt.title('Image compressée (q={})'.format(q))
plt.show()


#%%
# Entropy, file size, compression factor/rate and PSNR for the image above.

# A JPEG decoder delivers 8-bit samples: round and clip the float
# reconstruction (the inverse DCT can overshoot [0, 255]) before measuring
# or saving it.
img_compress_u8 = np.clip(np.round(img_compress_jpg), 0, 255).astype('uint8')

# Entropy, computed on 8-bit data for both images. On the raw float
# reconstruction nearly every value is distinct, which inflates the entropy
# and makes the comparison meaningless.
entropy_compress = shannon_entropy(img_compress_u8)
entropy_init = shannon_entropy(img)

# Save as JPEG and read back the file size
image.putdata(img_compress_u8.ravel())
image.save('img_compress_jpg.jpg', quality=100)
file_stat = os.stat('img_compress_jpg.jpg')
size_compress = file_stat.st_size  # in bytes
print(size_compress)

# Compression factor and compression rate (%)
F = size_init / size_compress
T = 100 * (1 - 1 / F)

# Signal-to-noise ratio between the initial and compressed images; float
# inputs so that the difference cannot wrap around in uint8.
psnr_val = PSNR(img.astype('float'), img_compress_u8.astype('float'))

print('Entropie : %.3f -> %.3f' % (entropy_init, entropy_compress))
print('Facteur de compression : %.3f, taux : %.2f %%' % (F, T))
print('PSNR : %.2f dB' % psnr_val)


# Sweep of the quality factor q = 1..99: entropy, compression factor and PSNR
qualities = np.arange(1, 100)
F_tab = np.zeros(len(qualities))
psnr_tab = np.zeros(len(qualities))
entropy_tab = np.zeros(len(qualities))

for i, q in enumerate(qualities):

    print(q)

    # Block-wise compression, then 8-bit output as a real decoder would give
    # (round and clip; a bare astype('uint8') would wrap out-of-range values).
    img_compress_jpg = blockproc(img.astype('float'), (8, 8), jpeg_compress, q)
    img_compress_u8 = np.clip(np.round(img_compress_jpg), 0, 255).astype('uint8')

    # Entropy
    entropy_tab[i] = shannon_entropy(img_compress_u8)

    # Save as JPEG and read back the file size (bytes)
    image.putdata(img_compress_u8.ravel())
    image.save('img_compress_jpg.jpg', quality=100)
    size_compress = os.stat('img_compress_jpg.jpg').st_size

    # Compression factor
    F_tab[i] = size_init / size_compress

    # PSNR (float inputs so the difference cannot wrap around in uint8)
    psnr_tab[i] = PSNR(img.astype('float'), img_compress_u8.astype('float'))

# One figure per metric, plotted against the actual quality factor
plt.figure()
plt.plot(qualities, entropy_tab)
plt.title('Entropie')
plt.xlabel('q')

plt.figure()
plt.plot(qualities, F_tab)
plt.title('Facteur de compression')
plt.xlabel('q')

plt.figure()
plt.plot(qualities, psnr_tab)
plt.title('PSNR')
plt.xlabel('q')
plt.show()


