import numpy as np
import matplotlib.pyplot as plt
import skimage
from skimage.feature import hog
from skimage import exposure

#%%
#HOG
# Demo input for the HOG / LBP cells below: `img` is the image as read by
# plt.imread and `ii` its grayscale version.
# NOTE(review): `ii` is not used anywhere in the visible code.
img = plt.imread('braided_0081.jpg')
ii = skimage.color.rgb2gray(img)



#%%

def hog_fct(img):
    """Compute the HOG descriptor of an image resized to 128x128.

    HOG parameters: 9 orientation bins, cells of 8x8 pixels, blocks of
    2x2 cells. A 3-D (RGB) input is first converted to grayscale with
    ``skimage.color.rgb2gray``, as the LBP functions of this file do, so the
    call does not depend on skimage's multichannel keyword.

    Parameters
    ----------
    img : ndarray
        Grayscale (2-D) or RGB (3-D) image of any size.

    Returns
    -------
    hog_desc : ndarray, shape (8100,)
        Flattened HOG feature vector: 16x16 cells -> 15x15 overlapping
        blocks, each block contributing 2*2*9 values.
    hog_image : ndarray, shape (128, 128)
        Visualisation image of the HOG, as returned by ``hog(visualize=True)``.
    """
    if img.ndim == 3:
        img = skimage.color.rgb2gray(img)
    # Resizing fixes the descriptor length, so every image yields the same p.
    img = skimage.transform.resize(img, [128, 128])
    hog_desc, hog_image = hog(
        img,
        orientations=9,
        pixels_per_cell=(8, 8),
        cells_per_block=(2, 2),
        visualize=True,
    )
    return hog_desc, hog_image

hog_desc, hog_image = hog_fct(img)

# Rescale histogram for better display
hog_image_rescaled = exposure.rescale_intensity(hog_image, in_range=(0, 10))

# Side by side: the input image (left) and its HOG visualisation (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

ax1.set_axis_off()
ax1.imshow(img, cmap=plt.cm.gray)
ax1.set_title('Input image')

ax2.set_axis_off()
ax2.imshow(hog_image_rescaled, cmap=plt.cm.gray)
ax2.set_title('Histogram of Oriented Gradients')

plt.show()

#%% LBP

def LBP(img, nhist=256):
    """Compute the basic 8-neighbour Local Binary Pattern of an image.

    Every pixel that has a full 3x3 neighbourhood is encoded as an 8-bit
    code: bit k is set when the k-th neighbour (clockwise, starting at the
    top-left one) is >= the centre pixel. Border pixels are not encoded.

    Parameters
    ----------
    img : ndarray
        Grayscale (2-D) or RGB (3-D) image; a 3-D input is converted with
        ``skimage.color.rgb2gray``.
    nhist : int
        Number of histogram bins, spread uniformly over the code range
        [0, 256). With 256 bins every code gets its own bin.

    Returns
    -------
    LBP_map : ndarray of int64, shape (h - 2, w - 2)
        LBP code of every interior pixel (empty for images under 3x3).
    LBP_histo : ndarray of int, shape (nhist,)
        Histogram of the codes.
    """
    if img.ndim == 3:
        img = skimage.color.rgb2gray(img)
    img = np.asarray(img, dtype=float)
    h, w = img.shape

    # Neighbour offsets (dy, dx), clockwise from the top-left; index = bit.
    offsets = ((-1, -1), (-1, 0), (-1, 1), (0, 1),
               (1, 1), (1, 0), (1, -1), (0, -1))

    center = img[1:h - 1, 1:w - 1]
    LBP_map = np.zeros(center.shape, dtype=np.int64)
    for bit, (dy, dx) in enumerate(offsets):
        # Shifted view aligned with `center`: neighbour (dy, dx) of each pixel.
        neighbour = img[1 + dy:h - 1 + dy, 1 + dx:w - 1 + dx]
        LBP_map += (neighbour >= center).astype(np.int64) << bit

    # Fixed range so that bin i means the same codes for every image; without
    # it np.histogram adapts the bins to each image's min/max code.
    LBP_histo, _ = np.histogram(LBP_map, bins=nhist, range=(0, 256))
    return LBP_map, LBP_histo

LBP_map, LBP_histo = LBP(img, 256)

# One figure with the input image and its LBP map side by side. show() must
# come after both subplots, otherwise the map ends up on a separate figure.
plt.figure()
plt.subplot(121)
plt.imshow(img, cmap='gray')
plt.title('Image')
plt.subplot(122)
plt.imshow(LBP_map, cmap='gray')
plt.title('LBP map')
plt.show()

# Code histograms over the full LBP code range [0, 256). Titles are set
# before show() so they appear on the displayed figure.
plt.figure()
plt.hist(LBP_map.ravel(), 128, range=(0, 256))
plt.title('LBP histogram 128-bin')
plt.show()

plt.figure()
plt.hist(LBP_map.ravel(), 256, range=(0, 256))
plt.title('LBP histogram 256-bin')
plt.show()


#%% Block-wise LBP

def LBP_BW(img, cell_size=16, nhist=256):
    """Block-wise LBP descriptor: one LBP code histogram per cell.

    The LBP map of `img` (computed by ``LBP``, which also handles RGB input)
    is tiled into cell_size x cell_size cells in row-major order. Each cell's
    codes are histogrammed over the fixed range [0, 256), so bins are
    comparable across cells and images. Cells on the right/bottom border
    are smaller when the map size is not a multiple of cell_size.

    Parameters
    ----------
    img : ndarray
        Grayscale (2-D) or RGB (3-D) image.
    cell_size : int
        Side length, in LBP-map pixels, of each cell.
    nhist : int
        Number of bins per cell histogram.

    Returns
    -------
    ndarray, shape (n_cells, nhist)
        Stacked cell histograms; n_cells depends on the image size.
    """
    LBP_map, _ = LBP(img, nhist)
    h, w = LBP_map.shape
    LBP_histo_BW = []
    for y in range(0, h, cell_size):
        for x in range(0, w, cell_size):
            cell = LBP_map[y:y + cell_size, x:x + cell_size]
            histo, _ = np.histogram(cell, bins=nhist, range=(0, 256))
            LBP_histo_BW.append(histo)
    # reshape keeps a 2-D (0, nhist) result when the map has no cells.
    return np.array(LBP_histo_BW).reshape(-1, nhist)

# Flattened block-wise LBP descriptor of the demo image.
LBP_histo_BW = np.ravel(LBP_BW(img))


#%% Compute for all images
# Learning / test image stacks followed by their labels, loaded in that order.
IL, IT, YL, YT = (
    np.load(path)
    for path in ('I_train.npy', 'I_test.npy', 'Y_train.npy', 'Y_test.npy')
)

#%%
feature = 1  # 1: global LBP histogram, 2: block-wise LBP, 3: HOG


def _descriptor(img, feature):
    """Return the flattened descriptor of one image for the chosen feature."""
    if feature == 1:
        _, desc = LBP(img)
    elif feature == 2:
        desc = LBP_BW(img)
    elif feature == 3:
        desc, _ = hog_fct(img)
    else:
        raise ValueError(f"feature must be 1, 2 or 3, got {feature!r}")
    return np.ravel(desc)


def _extract_features(images, feature):
    """Stack the descriptors of all images into a float (N, p) matrix.

    `images` is indexed as images[:, :, :, i], with the sample axis last.
    """
    descs = [_descriptor(images[:, :, :, i], feature)
             for i in range(images.shape[-1])]
    if not descs:
        raise ValueError("empty image stack")
    # vstack fails loudly if images produce descriptors of different lengths
    # (e.g. block-wise LBP on images of different sizes).
    return np.vstack(descs).astype(float)


XL = _extract_features(IL, feature)
XT = _extract_features(IT, feature)
# Descriptor length taken from the data itself rather than from the demo
# image, whose size may differ from the dataset's.
p = XL.shape[1]
print(f"{XL.shape = }, {XT.shape = }")


#%% SVM classifier

from sklearn.svm import SVC

# RBF-kernel SVM: fit on the learning descriptors/labels (labels flattened
# to 1-D), predict the test descriptors, and report the fraction of correct
# predictions.
classifier = SVC(kernel="rbf", C=100, random_state=0).fit(XL, YL.ravel())
YT_pred = classifier.predict(XT)
accuracy_SVM = (YT_pred == YT.ravel()).mean()
print(f"{accuracy_SVM = }")



#%% PCA
from sklearn.decomposition import PCA
from sklearn.svm import SVC

# PCA with 50 components, clamped because n_components may not exceed
# min(n_samples, n_features) of the data it is fitted on.
n_components = min(50, *XL.shape)
pca = PCA(n_components=n_components)
# Fit the projection on the learning data only, then apply it to both sets.
XL_pca = pca.fit_transform(XL)
XT_pca = pca.transform(XT)

# Same SVM configuration as the non-PCA experiment, for a fair comparison.
classifier_PCA = SVC(C=100, kernel="rbf", random_state=0)
classifier_PCA.fit(XL_pca, YL.ravel())
YT_pred_PCA = classifier_PCA.predict(XT_pca)
accuracy_SVM_PCA = np.mean(YT_pred_PCA == YT.ravel())
print(f"{accuracy_SVM_PCA = }")


