import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from params import *
# Compute device shared by the whole script: the first CUDA GPU when available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def count_parameters(model):
    """Return how many scalar values live in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` (frozen weights) are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Score every ``patch_size`` x ``patch_size`` patch of a 2-D image.

    Each patch is scored on an extended window that grows the patch by
    ``how_many`` pixels on every side, clipped at the image border.

    Args:
        image: 2-D numpy array or torch tensor (the window slicing below
            indexes exactly two axes).
        patch_size: side length of one patch in pixels.
        type: scoring rule, one of
            ``'variance'`` (standard deviation of the window),
            ``'mean_brightness'`` (mean),
            ``'contrast'`` (max - min),
            ``'edge_density'`` (sum of ``np.gradient`` magnitudes),
            ``'color_diversity'`` (standard deviation, same as ``'variance'``).
        how_many: number of extra pixels added on each side of the patch.
        noise_scale: scale of the standard-normal noise added to every score.

    Returns:
        numpy array of shape ``(H // patch_size, W // patch_size)``.
        Trailing rows/columns that do not fill a whole patch are ignored
        (previously they caused an IndexError).

    Raises:
        ValueError: if ``type`` is not a supported rule.
    """
    def _edge_density(window):
        # Sum of gradient magnitudes over the window.
        dy, dx = np.gradient(window)
        return np.sum(np.sqrt(dx ** 2 + dy ** 2))

    scorers = {
        'variance': np.std,
        'mean_brightness': np.mean,
        'contrast': lambda window: window.max() - window.min(),
        'edge_density': _edge_density,
        'color_diversity': np.std,
    }
    if type not in scorers:
        raise ValueError(f"Unknown importance type {type!r}; expected one of {sorted(scorers)}")
    score = scorers[type]

    if isinstance(image, torch.Tensor):
        # detach/cpu so GPU tensors and tensors that require grad also work.
        image = image.detach().cpu().numpy()

    H, W = image.shape[-2:]
    grid_h, grid_w = H // patch_size, W // patch_size
    value_map = np.zeros((grid_h, grid_w))

    for r in range(grid_h):
        i = r * patch_size
        start_i, end_i = max(i - how_many, 0), min(i + patch_size + how_many, H)
        for c in range(grid_w):
            j = c * patch_size
            start_j, end_j = max(j - how_many, 0), min(j + patch_size + how_many, W)
            value = score(image[start_i:end_i, start_j:end_j])
            # One randn draw per patch even when noise_scale == 0, so the global
            # numpy RNG stream advances exactly as before.
            noise = np.random.randn() * noise_scale
            value_map[r, c] = value + noise

    return value_map

def _rebalance_chessboard(mask, importance_map, mask_ratio, target_unmasked):
    """Flip cells of the boolean grid ``mask`` (True = masked) in place and return it.

    For ``mask_ratio < 0.5`` the most important masked cells are unmasked
    until ``target_unmasked`` cells are unmasked. For ``mask_ratio > 0.5``
    the least important unmasked cells are masked until only
    ``target_unmasked`` remain. ``mask_ratio == 0.5`` keeps the chessboard.
    """
    unmasked_count = np.sum(~mask)
    if mask_ratio < 0.5:
        candidates = np.argwhere(mask)
        order = np.argsort(importance_map[mask])[::-1]  # descending importance
        for idx in candidates[order]:
            if unmasked_count >= target_unmasked:
                break
            mask[tuple(idx)] = False
            unmasked_count += 1
    elif mask_ratio > 0.5:
        candidates = np.argwhere(~mask)
        order = np.argsort(importance_map[~mask])  # ascending importance
        for idx in candidates[order]:
            if unmasked_count <= target_unmasked:
                break
            mask[tuple(idx)] = True
            unmasked_count -= 1
    return mask


def _default_split_len(mask_ratio, unmasked_per_image, patch_size):
    """Return the strip width that folds one image's unmasked patches into a square."""
    if mask_ratio == 0.33984:
        # Legacy hard-coded width (e.g. 32x32 input, patch_size=2 -> 169 patches -> 26x26).
        return 26
    side = math.isqrt(unmasked_per_image)
    if side * side != unmasked_per_image:
        raise ValueError(
            f"{unmasked_per_image} unmasked patches per image cannot be folded into a "
            "square; pass split_len explicitly")
    return side * patch_size


def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1, noise_scale=0, split_len=None):
    """Mask a batch of images with an importance-adjusted chessboard pattern.

    The patch grid starts as a chessboard (cells with even ``i + j`` masked).
    It is then rebalanced toward ``int(num_patches * (1 - mask_ratio))``
    unmasked cells, using ``patch_importance`` scores computed on channel 0
    of each image.

    Args:
        images: tensor ``[B, C, H, W]`` with ``B >= 1``.
        patch_size: side length of one square patch.
        mask_ratio: target fraction of masked patches.
        importance_type, how_many, noise_scale: forwarded to ``patch_importance``.
        split_len: width used to fold the unmasked-patch strip. ``None`` keeps
            the legacy value 26 for ``mask_ratio == 0.33984``; otherwise it
            folds each image's patches into a square, which requires a
            perfect-square unmasked count.

    Returns:
        masked_images: copy of ``images`` with every masked patch set to 0.
        reshaped: all unmasked patches of all images, concatenated into a
            ``(C, patch_size, total_width)`` strip, cut into ``split_len``-wide
            pieces and stacked along dim 1.
        patch_index: 1-D tensor of row-major grid indices (``W // patch_size``
            cells per row) of the unmasked patches, in strip order, for all
            images concatenated.

    Raises:
        ValueError: on an empty batch, when every patch is masked, or when
            the strip width is not a multiple of ``split_len``.
    """
    B, C, H, W = images.shape
    if B == 0:
        raise ValueError("chessboard_mask needs at least one image")

    grid_h, grid_w = H // patch_size, W // patch_size
    target_unmasked_patches = int(grid_h * grid_w * (1 - mask_ratio))
    # Base chessboard, True = masked.
    chessboard = (np.arange(grid_h)[:, None] + np.arange(grid_w)[None, :]) % 2 == 0

    masked_images = images.clone()
    unmasked_patches = []
    patch_index = []
    unmasked_per_image = 0

    for b in range(B):
        importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)
        mask = _rebalance_chessboard(chessboard.copy(), importance_map, mask_ratio, target_unmasked_patches)

        patches = []
        for i in range(grid_h):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(grid_w):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                if mask[i, j]:
                    masked_images[b, :, rows, cols] = 0
                else:
                    patches.append(images[b, :, rows, cols])
                    # Row-major index over a grid that is grid_w cells wide.
                    patch_index.append(grid_w * i + j)

        if not patches:
            raise ValueError("every patch is masked; lower mask_ratio")
        unmasked_per_image = len(patches)
        unmasked_patches.append(torch.cat(patches, dim=-1))

    unmasked_patches_image = torch.cat(unmasked_patches, dim=-1)
    if split_len is None:
        split_len = _default_split_len(mask_ratio, unmasked_per_image, patch_size)
    if unmasked_patches_image.shape[-1] % split_len:
        raise ValueError(
            f"strip width {unmasked_patches_image.shape[-1]} is not a multiple of split_len={split_len}")

    split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
    reshaped = torch.cat(split_tensor, dim=1)

    return masked_images, reshaped, torch.tensor(patch_index)

class Encoder1D(nn.Module):
    """Convolutional encoder from a 3-channel image to a ``latent_dim`` vector.

    Each stride-2 convolution (kernel 3, padding 1) maps a side length ``n``
    to ``ceil(n / 2)``. A 26 x 26 input therefore goes 26 -> 13 -> 7 -> 4 -> 2,
    giving 128 * 2 * 2 = 512 flattened features. The fixed
    ``nn.Linear(512, latent_dim)`` requires the last conv to output 2 x 2 maps,
    which holds for H and W in 17..32.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # NOTE(review): no activation sits between these layers, so the whole
        # encoder is one affine map; confirm this is intentional.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.Flatten(),  # [batch, 128 * 2 * 2] = [batch, 512] for a 26 x 26 input
            nn.Linear(512, self.latent_dim),
        )

    def forward(self, x):
        """Encode ``x`` of shape ``[batch, 3, H, W]`` into ``[batch, latent_dim]``."""
        return self.encoder(x)

class Decoder1D(nn.Module):
    """Transposed-convolution decoder from a latent vector to a 3-channel image.

    Shapes: ``[batch, latent_dim]`` -> 128x2x2 -> 64x4x4 -> 64x7x7 -> 32x13x13
    -> ``[batch, 3, 26, 26]``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 128 * 2 * 2),
            nn.Unflatten(1, (128, 2, 2)),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch, 64, 4, 4]
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=0),   # -> [batch, 64, 7, 7]
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=0),   # -> [batch, 32, 13, 13]
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),    # -> [batch, 3, 26, 26]
        )

    def reconstruct_masked_image(self, unmasked_patches, patch_indices, image_shape, patch_size):
        """Scatter patches back onto an all-zero canvas of ``image_shape``.

        For image ``b``, slice ``k`` of ``unmasked_patches[b, :, ...]`` (width
        ``patch_size`` along the last axis) is written to the grid cell with
        row-major index ``patch_indices[b][k]`` (``W // patch_size`` cells per
        row). Cells that are not listed stay zero.

        NOTE(review): each slice is broadcast across the patch columns via
        ``unsqueeze(2)``, which assumes ``unmasked_patches`` is ``[B, C, L]``;
        TODO confirm the intended layout. ``forward`` does not call this.
        """
        B, C, H, W = image_shape
        cells_per_row = W // patch_size
        reconstructed_images = torch.zeros((B, C, H, W), device=unmasked_patches.device)

        for b in range(B):
            for k, linear_idx in enumerate(patch_indices[b]):
                i = linear_idx // cells_per_row
                j = linear_idx % cells_per_row
                patch = unmasked_patches[b, :, patch_size * k: patch_size * (k + 1)]
                reconstructed_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = patch.unsqueeze(2)

        return reconstructed_images

    def forward(self, x, patch_indices, image_shape, patch_size):
        """Decode ``x``; ``patch_indices``, ``image_shape`` and ``patch_size`` are currently unused."""
        return self.decoder(x)

class Autoencoder1D(nn.Module):
    """End-to-end model: encoder -> power normalisation -> noisy channel -> decoder.

    Args:
        latent_dim: length of the latent vector sent through the channel. The
            Rayleigh path pairs consecutive entries into complex symbols, so it
            should be even.
        patch_size: stored on the instance; ``forward`` takes its own ``patch_size``.
    """

    def __init__(self, latent_dim, patch_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.encoder = Encoder1D(latent_dim)
        self.decoder = Decoder1D(latent_dim)

    def Power_norm(self, z, P=1 / np.sqrt(2)):
        """Rescale each row of the real ``[batch, z_dim]`` tensor ``z`` to squared L2 norm ``P * z_dim``."""
        _, z_dim = z.shape
        row_norm = torch.sqrt(torch.sum(z ** 2, dim=1, keepdim=True))
        # math.sqrt gives a Python float, so z's dtype is preserved.
        return math.sqrt(P * z_dim) * z / row_norm

    def Power_norm_complex(self, z, P=1 / np.sqrt(2)):
        """Normalise ``z`` viewed as complex symbols (even columns real, odd columns imaginary).

        Each row is scaled so that the sum of ``|symbol|^2`` equals ``P * z_dim``.
        The result is returned in the same interleaved real layout.
        """
        batch_size, z_dim = z.shape
        real, imag = z[:, 0::2], z[:, 1::2]
        z_com = torch.complex(real, imag)
        z_power = torch.sum(z_com * torch.complex(real, -imag), dim=1, keepdim=True).real
        z_nlz = math.sqrt(P * z_dim) * z_com / torch.sqrt(z_power)
        # Allocate on z's device rather than a global device.
        z_out = torch.zeros(batch_size, z_dim, device=z.device, dtype=z.dtype)
        z_out[:, 0::2] = z_nlz.real
        z_out[:, 1::2] = z_nlz.imag
        return z_out

    def AWGN_channel(self, x, snr, P=1):
        """Add Gaussian noise of variance ``P / 10**(snr / 10)`` to the real ``[batch, length]`` tensor ``x``."""
        batch_size, length = x.shape
        gamma = 10 ** (snr / 10.0)
        # Noise is drawn on x's device (the old code hard-coded .cuda(), breaking CPU runs).
        noise = math.sqrt(P / gamma) * torch.randn(batch_size, length, device=x.device)
        return x + noise

    def Fading_channel(self, x, snr, P=1):
        """Pass interleaved complex symbols through a random multiplicative channel.

        Each symbol is multiplied by a coefficient ``h`` with standard-normal
        real and imaginary parts. Complex noise with per-component variance
        ``P / 10**(snr / 10)`` is added, and the result is divided by the same
        ``h``. The output uses the input's interleaved real layout.

        NOTE(review): both parts of ``h`` have unit variance, so E|h|^2 = 2;
        confirm whether a 1/sqrt(2) scaling was intended.
        """
        gamma = 10 ** (snr / 10.0)
        batch_size, feature_length = x.shape
        K = feature_length // 2

        h_I = torch.randn(batch_size, K, device=x.device)
        h_R = torch.randn(batch_size, K, device=x.device)
        h_com = torch.complex(h_I, h_R)  # h_I is the real part, h_R the imaginary part
        x_com = torch.complex(x[:, 0::2], x[:, 1::2])
        y_com = h_com * x_com

        noise_std = math.sqrt(P / gamma)
        n_I = noise_std * torch.randn(batch_size, K, device=x.device)
        n_R = noise_std * torch.randn(batch_size, K, device=x.device)
        y = (y_com + torch.complex(n_I, n_R)) / h_com

        y_out = torch.zeros(batch_size, feature_length, device=x.device, dtype=x.dtype)
        y_out[:, 0::2] = y.real
        y_out[:, 1::2] = y.imag
        return y_out

    def forward(self, x, SNRdB, channel, patch_index, image_shape, patch_size):
        """Encode ``x``, send it over ``channel`` ('AWGN' or 'Rayleigh') at ``SNRdB``, then decode.

        Raises:
            ValueError: for any other ``channel`` value (previously an UnboundLocalError).
        """
        if channel == 'AWGN':
            normalize, transmit = self.Power_norm, self.AWGN_channel
        elif channel == 'Rayleigh':
            normalize, transmit = self.Power_norm_complex, self.Fading_channel
        else:
            raise ValueError(f"Unsupported channel {channel!r}; expected 'AWGN' or 'Rayleigh'")
        encoded = self.encoder(x)
        channel_output = transmit(normalize(encoded), SNRdB)
        return self.decoder(channel_output, patch_index, image_shape, patch_size)

def preprocess_and_save_dataset(dataset, root_dir, patch_size, mask_ratio, importance_type, how_many, noise_scale):
    """Apply ``chessboard_mask`` to every sample and save it as ``root_dir/data_<i>.pt``.

    Each file holds a dict with keys ``'original_images'``, ``'masked_images'``,
    ``'unmasked_patches'`` and ``'patch_index'``. ``dataset`` must yield
    ``(image, label)`` pairs whose image is a ``[C, H, W]`` tensor; the label
    is ignored.
    """
    os.makedirs(root_dir, exist_ok=True)
    for i, (images, _) in tqdm(enumerate(dataset), total=len(dataset)):
        # chessboard_mask works on batches, so add and later drop a batch axis of 1.
        masked_images, unmasked_patches_image, patch_index = chessboard_mask(
            images.unsqueeze(0), patch_size, mask_ratio, importance_type, how_many, noise_scale)

        torch.save({
            'original_images': images,
            'masked_images': masked_images.squeeze(0),
            'unmasked_patches': unmasked_patches_image,
            'patch_index': patch_index,
        }, os.path.join(root_dir, f'data_{i}.pt'))

class PreprocessedCIFAR10Dataset(Dataset):
    """Dataset over the dict-valued ``.pt`` files in ``root_dir``.

    Each file must contain the keys ``'original_images'``, ``'masked_images'``,
    ``'unmasked_patches'`` and ``'patch_index'``. Items are returned as the
    tuple ``(original_images, masked_images, unmasked_patches, patch_index)``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.file_paths = sorted(
            os.path.join(root_dir, name) for name in os.listdir(root_dir) if name.endswith('.pt')
        )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # weights_only=False unpickles arbitrary objects; only use on trusted,
        # locally generated cache directories.
        data = torch.load(self.file_paths[idx], weights_only=False)

        original_images = data['original_images']
        masked_images = data['masked_images']
        unmasked_patches = data['unmasked_patches']
        patch_index = data['patch_index']

        if self.transform:
            # NOTE(review): the same transform is also applied to patch_index
            # (an index tensor); confirm that is intended.
            unmasked_patches = self.transform(unmasked_patches)
            masked_images = self.transform(masked_images)
            original_images = self.transform(original_images)
            patch_index = self.transform(patch_index)

        return original_images, masked_images, unmasked_patches, patch_index

def _run_epoch(model, loader, criterion, snr_db, patch_size, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is in train mode and its weights are
    updated after every batch. Without one it is evaluated in eval mode with
    gradients disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)
    total_loss = 0.0
    with torch.set_grad_enabled(is_training):
        for original_images, _masked_images, unmasked_patches, patch_index in loader:
            unmasked_patches = unmasked_patches.squeeze(2).to(device)
            outputs = model(unmasked_patches, SNRdB=snr_db, channel=params['channel'],
                            patch_index=patch_index, image_shape=original_images.shape,
                            patch_size=patch_size)
            # The model reconstructs its own input (the packed unmasked patches).
            loss = criterion(outputs, unmasked_patches)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / len(loader)


def train(latent_dim, patch_size, mask_ratio, trainloader, testloader, ES, IT, HM, NS):
    """Train one fresh ``Autoencoder1D`` per SNR in ``params['SNR']``.

    Uses Adam with ReduceLROnPlateau on the validation loss. Training stops
    early after ``ES`` consecutive epochs without validation improvement.
    Every time validation PSNR improves, the whole model is saved under
    ``trained_model/``. ``IT``, ``HM`` and ``NS`` are used only in the saved
    file name.
    """
    save_folder = 'trained_model'
    for snr in params['SNR']:
        model = Autoencoder1D(latent_dim=latent_dim, patch_size=patch_size).to(device)
        print("Model size : {}".format(count_parameters(model)))
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=params['LR'])
        # `verbose` is deprecated for LR schedulers, so it is not passed.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=8, factor=0.5)

        min_test_cost = float('inf')
        epochs_no_improve = 0
        n_epochs_stop = ES

        print("+++++ SNR = {} Training Start! +++++\t".format(snr))

        # -inf so the first evaluated model is always saved.
        max_psnr = float('-inf')
        previous_best_model_path = None

        for epoch in range(params['EP']):
            timetemp = time.time()

            train_cost = _run_epoch(model, trainloader, criterion, snr, patch_size, optimizer)
            tr_psnr = round(10 * math.log10(1.0 / train_cost), 3)

            test_cost = _run_epoch(model, testloader, criterion, snr, patch_size)
            test_psnr = round(10 * math.log10(1.0 / test_cost), 3)
            scheduler.step(test_cost)

            # Early stopping: count consecutive epochs without validation improvement.
            if test_cost < min_test_cost:
                min_test_cost = test_cost
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == n_epochs_stop:
                print("Early stopping!")
                break  # stop training for this SNR

            training_time = time.time() - timetemp
            print(
                "[{:>3}-Epoch({:>5}sec.)]  PSNR(Train / Val) : {:>6.4f} / {:>6.4f}".format(
                    epoch + 1, round(training_time, 2), tr_psnr, test_psnr))

            if test_psnr > max_psnr:
                os.makedirs(save_folder, exist_ok=True)
                previous_psnr = max_psnr
                max_psnr = test_psnr

                if previous_best_model_path is not None:
                    print(f"Performance update!! {previous_psnr} to {max_psnr}")

                save_path = os.path.join(save_folder, f"CBM(PS={patch_size}_DIM={latent_dim}_MR={mask_ratio}_IT={IT}_HM={HM}_NS={NS}_SNR={snr}_PSNR={max_psnr}).pt")
                torch.save(model, save_path)
                print(f"Saved new best model at {save_path}")

                previous_best_model_path = save_path

if __name__ == '__main__':
    # Sweep every (patch size, latent dim, mask ratio, how_many) combination.
    # Masked datasets are cached on disk and rebuilt only when missing.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    for patch_size in params['PS']:
        for latent_dim in params['DIM']:
            for mask_ratio in params['MR']:
                for how_many in params['HM']:
                    # NOTE(review): the cache name omits NS, so changing params['NS']
                    # silently reuses data built with the old noise scale.
                    suffix = f"(PS={patch_size}_MR={mask_ratio}_IT={params['IT']}_HM={how_many})"
                    Processed_train_path = "ProcessedTrain" + suffix
                    Processed_test_path = "ProcessedTest" + suffix

                    if not os.path.exists(Processed_train_path):
                        train_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=to_tensor)
                        preprocess_and_save_dataset(train_cifar10, Processed_train_path, patch_size=patch_size, mask_ratio=mask_ratio, importance_type=params['IT'], how_many=how_many, noise_scale=params['NS'])

                    if not os.path.exists(Processed_test_path):
                        test_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=to_tensor)
                        preprocess_and_save_dataset(test_cifar10, Processed_test_path, patch_size=patch_size, mask_ratio=mask_ratio, importance_type=params['IT'], how_many=how_many, noise_scale=params['NS'])

                    traindataset = PreprocessedCIFAR10Dataset(root_dir=Processed_train_path)
                    testdataset = PreprocessedCIFAR10Dataset(root_dir=Processed_test_path)

                    trainloader = DataLoader(traindataset, batch_size=params['BS'], shuffle=True, num_workers=4, drop_last=True)
                    # No shuffling for evaluation: with drop_last, shuffling would drop a
                    # different subset each epoch and make the validation metric noisy.
                    testloader = DataLoader(testdataset, batch_size=params['BS'], shuffle=False, num_workers=4, drop_last=True)

                    train(latent_dim, patch_size, mask_ratio, trainloader, testloader, params['ES'], params['IT'], how_many, params['NS'])
# Experiment configuration.
# NOTE(review): this assignment runs after the `__main__` block above, so a
# script run uses the `params` imported from the `params` module instead.
params = {
    'BS': 64,               # DataLoader batch size
    'LR': 0.001,            # Adam learning rate
    'EP': 5000,             # maximum number of epochs
    'SNR': [40, 0],         # channel SNRs in dB; one model is trained per value
    'DIM': [512, 256],      # latent dimensions to sweep
    'MR' : [0.33984],       # mask ratios to sweep
    'PS' : [2],             # patch sizes to sweep
    'ES' : 40,              # early-stopping patience in epochs
    'IT' : 'variance',      # patch importance type
    'HM' : [1, 2, 3, 4],    # how_many (importance window margin) values to sweep
    'NS' : 0,               # noise_scale added to importance scores
    'channel' : 'Rayleigh'  # 'AWGN' or 'Rayleigh'
}
