import torch
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
from torch.utils.data import DataLoader, Dataset
import time
from params import *
import os
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class Encoder(nn.Module):
    """Convolutional encoder mapping a [batch, 3, 32, 32] image to [batch, latent_dim].

    The final ``nn.Linear`` hard-codes a 128 x 4 x 4 feature map, i.e. three
    stride-2 convolutions applied to a 32 x 32 input; other input sizes will fail.

    Args:
        latent_dim: size of the output latent vector.
    """

    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        # NOTE(review): there is no activation between layers, so this encoder is
        # an affine map; confirm this matches the architecture used for training.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),  # Output: [batch, 32, 16, 16]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # Output: [batch, 64, 8, 8]
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # Output: [batch, 128, 4, 4]
            # Linear acts on the last dimension only, so the feature map must be
            # flattened to [batch, 128 * 4 * 4] first.
            nn.Flatten(),
            nn.Linear(4 * 4 * 128, self.latent_dim),  # Output: [batch, latent_dim]
        )

    def forward(self, x):
        """Encode ``x`` ([batch, 3, 32, 32]) into a [batch, latent_dim] tensor."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Decoder mapping a [batch, latent_dim] vector back to a [batch, 3, 32, 32] image.

    Mirrors ``Encoder``: a Linear layer restores the 128 x 4 x 4 feature map, then
    three stride-2 transposed convolutions upsample it to 32 x 32.

    Args:
        latent_dim: size of the input latent vector.
    """

    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        # NOTE(review): there are no activations, including none on the output, so
        # reconstructed values are not constrained to [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 4 * 4 * 128),
            nn.Unflatten(1, (128, 4, 4)),  # Output: [batch, 128, 4, 4]
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Output: [batch, 64, 8, 8]
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # Output: [batch, 32, 16, 16]
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # Output: [batch, 3, 32, 32]
        )

    def forward(self, x):
        """Decode ``x`` ([batch, latent_dim]) into a [batch, 3, 32, 32] tensor."""
        return self.decoder(x)

class Autoencoder(nn.Module):
    """Encoder -> power-normalized AWGN channel -> decoder.

    Args:
        latent_dim: size of the latent vector passed from ``Encoder`` to
            ``Decoder``; it also sets the noise scale in :meth:`AWGN`.
    """

    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def AWGN(self, input, SNRdB):
        """Normalize ``input`` along dim 1 and add white Gaussian noise at ``SNRdB``.

        Each row is scaled to unit L2 norm, so a row of ``latent_dim`` entries has
        average per-entry power ``1 / latent_dim``. Noise with per-entry variance
        ``1 / (latent_dim * SNR)``, where ``SNR = 10 ** (SNRdB / 10)``, therefore
        gives a per-entry signal-to-noise ratio of ``SNR``. This assumes dim 1 has
        size ``latent_dim``, as the encoder output does.

        Args:
            input: latent tensor, presumably [batch, latent_dim].
            SNRdB: channel signal-to-noise ratio in decibels (a Python number).

        Returns:
            Noisy tensor with the same shape, dtype and device as ``input``.
        """
        normalized = f.normalize(input, dim=1)
        snr_linear = 10.0 ** (SNRdB / 10.0)
        std = 1.0 / math.sqrt(self.latent_dim * snr_linear)
        # Draw the noise directly on input's device and dtype instead of on the
        # CPU followed by a move to a module-global device, which breaks when the
        # model lives on a different device.
        noise = torch.randn_like(normalized) * std
        return normalized + noise

    def forward(self, x, SNRdB):
        """Encode ``x``, pass the latent through the AWGN channel, and decode it."""
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        return self.decoder(channel_output)

def patch_std(image, patch_size=2):
    """Return the standard deviation of every non-overlapping square patch.

    Args:
        image: 2-D array of shape (H, W).
        patch_size: side length of each square patch.

    Returns:
        float64 array of shape ``(H // patch_size, W // patch_size)``. Entry
        ``[i, j]`` is the std of
        ``image[i*ps:(i+1)*ps, j*ps:(j+1)*ps]``. Trailing rows or columns that do
        not fill a complete patch are ignored rather than indexed out of bounds.
    """
    H, W = image.shape
    rows, cols = H // patch_size, W // patch_size
    # Crop to whole patches, then split each axis into (patch index, offset in
    # patch) so every patch occupies axes 1 and 3 of a 4-D view.
    cropped = np.asarray(image)[:rows * patch_size, :cols * patch_size]
    patches = cropped.reshape(rows, patch_size, cols, patch_size)
    return patches.std(axis=(1, 3)).astype(np.float64)

def _to_gray_uint8(image):
    """Convert one CHW image tensor with values in [0, 1] to an (H, W) uint8 grayscale array."""
    array = image.detach().permute(1, 2, 0).cpu().numpy() * 255
    # Clip before the cast so out-of-range values saturate instead of wrapping.
    array = np.clip(array, 0, 255).astype(np.uint8)
    if array.shape[2] == 1:
        return array[:, :, 0]
    return cv2.cvtColor(array, cv2.COLOR_RGB2GRAY)


def _adjust_mask_by_complexity(mask, complexity_map, mask_ratio):
    """Grow or shrink a boolean patch mask to ``round(mask_ratio * mask.size)`` patches.

    Extra patches are masked starting from the least complex visible ones, and
    patches are revealed starting from the most complex masked ones, so complex
    patches are preferentially kept. Counts are exact even when complexities tie.
    """
    flat_mask = mask.ravel().copy()
    flat_complexity = np.asarray(complexity_map).ravel()
    target = int(round(mask_ratio * flat_mask.size))
    current = int(flat_mask.sum())
    if target > current:
        candidates = np.flatnonzero(~flat_mask)
        order = np.argsort(flat_complexity[candidates], kind='stable')
        flat_mask[candidates[order[:target - current]]] = True
    elif target < current:
        candidates = np.flatnonzero(flat_mask)
        order = np.argsort(-flat_complexity[candidates], kind='stable')
        flat_mask[candidates[order[:current - target]]] = False
    return flat_mask.reshape(mask.shape)


def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    """Zero out square patches of each image in a chessboard pattern.

    The image is divided into ``patch_size`` x ``patch_size`` patches. Patch
    (row, col) is masked when ``row + col`` is odd, so the top-left patch is kept.
    When ``complexity_based`` is True and ``mask_ratio != 0.5``, the chessboard
    is then adjusted per image to mask ``round(mask_ratio * n_patches)`` patches.
    The adjustment masks more low-complexity patches or reveals high-complexity
    ones, with complexity measured as the per-patch grayscale std (``patch_std``).
    Without ``complexity_based``, ``mask_ratio`` has no effect.

    Trailing rows/columns that do not form a complete patch are left unmasked.

    Args:
        images: tensor of shape (B, C, H, W), presumably with values in [0, 1]
            (the complexity path converts via ``* 255`` to uint8); C is 3 (RGB)
            or 1 for the complexity path.
        patch_size: side length of a patch in pixels.
        mask_ratio: target fraction of masked patches (complexity path only).
        complexity_based: enable the complexity-driven adjustment.

    Returns:
        New tensor with the same shape, dtype and device as ``images``. Masked
        pixels are 0, and unmasked pixels keep their exact input values.
    """
    B, C, H, W = images.shape
    grid_h, grid_w = H // patch_size, W // patch_size

    rows, cols = np.indices((grid_h, grid_w))
    chessboard = (rows + cols) % 2 == 1

    pixel_masks = np.zeros((B, 1, H, W), dtype=bool)
    for b in range(B):
        patch_mask = chessboard
        if complexity_based and mask_ratio != 0.5:
            complexity_map = patch_std(_to_gray_uint8(images[b]), patch_size)
            patch_mask = _adjust_mask_by_complexity(chessboard, complexity_map, mask_ratio)
        # Expand the patch-level decision to pixels over the whole-patch area.
        pixel_masks[b, 0, :grid_h * patch_size, :grid_w * patch_size] = (
            patch_mask.repeat(patch_size, axis=0).repeat(patch_size, axis=1)
        )

    # Mask the original tensor directly (no uint8 round trip that would truncate
    # the unmasked pixels); masked_fill returns a new tensor on images' device.
    return images.masked_fill(torch.from_numpy(pixel_masks).to(images.device), 0)

def load_model(model_path, map_location="cpu"):
    """Load a fully pickled ``nn.Module`` from ``model_path`` and set it to eval mode.

    Args:
        model_path: path to a file written with ``torch.save(model, path)``
            (a whole module, not a state_dict).
        map_location: forwarded to ``torch.load``. It defaults to ``"cpu"`` so
            checkpoints saved on a GPU also load on CPU-only machines; move the
            returned model to the target device afterwards.

    Returns:
        The loaded module, in eval mode.
    """
    # A whole pickled module needs weights_only=False (PyTorch 2.6 changed the
    # default to True, which rejects arbitrary classes). Unpickling can execute
    # arbitrary code, so only load checkpoints from trusted sources.
    model = torch.load(model_path, map_location=map_location, weights_only=False)
    model.eval()  # Put the model in evaluation mode
    return model

# CIFAR-10 test split (downloaded to ./data on first use, at import time), turned
# into tensors by ToTensor and served one shuffled image per batch.
transform = transforms.Compose([transforms.ToTensor()])
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
testloader = DataLoader(dataset=testset, shuffle=True, batch_size=1)

def visualize_reconstruction(model, data_loader):
    """Plot one image next to its masked version and the model's reconstruction, then show it.

    Takes the first batch from ``data_loader`` and masks it with
    ``mask_patches_chessboard``, using ``params['PS']`` / ``params['MR']`` and
    complexity-based masking. It then runs ``model`` at ``params['SNR']`` dB
    without gradients and displays the first image of the batch in three panels.

    Args:
        model: autoencoder called as ``model(images, snr_db)``.
        data_loader: iterable yielding ``(images, labels)`` batches of CHW images.
    """
    model.eval()  # Put the model in evaluation mode
    images, _labels = next(iter(data_loader))  # Take one batch from the loader

    # Apply masking
    masked_images = mask_patches_chessboard(
        images, patch_size=params['PS'], mask_ratio=params['MR'], complexity_based=True
    ).to(device)

    # Reconstruct the masked images with the model
    with torch.no_grad():
        reconstructed_images = model(masked_images, params['SNR'])

    panels = (
        ('Original Image', images[0]),
        ('Masked Image', masked_images[0]),
        # The decoder ends in a ConvTranspose2d with no output activation, so
        # values can leave [0, 1]; clamp explicitly instead of relying on
        # imshow's clipping (which emits a warning).
        ('Reconstructed Image', reconstructed_images[0].clamp(0.0, 1.0)),
    )

    plt.figure(figsize=(9, 3))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
        plt.title(title)
    plt.tight_layout()
    # Without show() the figure is never displayed when run as a script.
    plt.show()

# Checkpoint file name encodes the experiment hyper-parameters from params.
model_path = 'trained_model/CBS(PS={!s}_DIM={!s}_MR={!s}_SNR={!s}_PSNR={!s}).pt'.format(
    params['PS'],
    params['DIM'],
    params['MR'],
    params['SNR'],
    params['PSNR'],
)

model = load_model(model_path).to(device)

# Visualize the reconstructed image
visualize_reconstruction(model, testloader)
