UOMOP

ㅇㄹㅇㄹ 본문

MAE

ㅇㄹㅇㄹ

Happy PinGu 2024. 5. 23. 21:09
import torch
import torch.nn as nn
import math
import torch.nn.functional as f
import torch.optim as optim
from params import *
import time
import os
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
from torch.utils.data import DataLoader, Dataset
import time
from params import *
import os
from tqdm import tqdm
import numpy as np
import cv2


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def masked2vector(input, patch_size) :
    """Flatten the unmasked patches of a masked RGB image into one vector.

    The image is scanned patch by patch in row-major order. A patch is
    treated as *masked* (and skipped) when the top-left pixel of its first
    channel is exactly 0; every other patch contributes its flattened R, G
    and B values (in that order) to the output vector.

    Args:
        input: image tensor of shape (1, C, H, W) or (C, H, W) with C >= 3.
        patch_size: side length, in pixels, of the square patches.

    Returns:
        A tuple ``(vector, indice)``: ``vector`` is a 1-D CPU tensor with
        ``3 * patch_size**2`` values per kept patch, and ``indice`` is the
        list of row-major patch indices that were kept.
    """
    input = input.squeeze(0)

    # Derive the patch grid from the real image size and patch size instead
    # of assuming 2x2 patches on a square image.
    num_rows = input.shape[1] // patch_size
    num_cols = input.shape[2] // patch_size

    kept_patches = []
    indice = []

    for i in range(num_rows):
        row_span = slice(patch_size * i, patch_size * (i + 1))
        for j in range(num_cols):
            # Masked patches were zeroed out, so a zero first pixel marks them.
            if input[0, patch_size * i, patch_size * j].item() == 0:
                continue

            col_span = slice(patch_size * j, patch_size * (j + 1))
            indice.append(i * num_cols + j)
            kept_patches.append(
                torch.cat([input[channel, row_span, col_span].flatten() for channel in range(3)])
            )

    if not kept_patches:
        return torch.tensor([]), indice

    # One concatenation replaces the per-scalar .item() copy; the result is
    # detached and moved to CPU, as a tensor rebuilt from Python floats would be.
    output_vector = torch.cat(kept_patches).detach().cpu()
    if output_vector.is_floating_point():
        output_vector = output_vector.to(torch.get_default_dtype())

    return output_vector, indice

def patch_std(image, patch_size=2):
    """Compute a noisy per-patch standard-deviation map of a 2-D image.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches; trailing rows/columns that do not fill a whole patch are
    ignored. Each output cell holds the standard deviation of its patch plus
    a small Gaussian perturbation drawn from NumPy's global RNG.

    Args:
        image: 2-D array of shape (H, W).
        patch_size: side length of the square patches.

    Returns:
        A float array of shape (H // patch_size, W // patch_size).
    """
    H, W = image.shape
    rows, cols = H // patch_size, W // patch_size
    std_map = np.zeros((rows, cols))
    noise_scale = 1e-2  # small noise scale; presumably breaks ties between equal patches

    # Iterate over whole patches only, so a partial patch at the border can
    # never index past the end of std_map.
    for r in range(rows):
        for c in range(cols):
            patch = image[r * patch_size:(r + 1) * patch_size,
                          c * patch_size:(c + 1) * patch_size]
            std_map[r, c] = np.std(patch) + np.random.randn() * noise_scale

    return std_map
def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    """Mask image patches in a chessboard pattern, optionally complexity-tuned.

    With ``mask_ratio == 0.5`` every other patch is zeroed in a chessboard
    layout (the top-left patch is kept). For any other ratio the same
    chessboard is the starting point and, when ``complexity_based`` is true,
    it is adjusted using the per-patch standard deviation of the grayscale
    image:

    * ``mask_ratio > 0.5``: the lowest-complexity visible patches are masked
      as well, until roughly ``mask_ratio`` of all patches are masked.
    * ``mask_ratio < 0.5``: the highest-complexity masked patches are
      revealed again, until roughly ``mask_ratio`` of all patches are masked.

    Images pass through an 8-bit round trip (scaled by 255, cast to uint8,
    scaled back), so visible pixels are quantised to 1/255 steps.

    Args:
        images: float tensor of shape (B, C, H, W); pixel values are assumed
            to lie in [0, 1] because of the 8-bit round trip. The return path
            squeezes dim 0 and flattens a single image, so B is expected to
            be 1.
        patch_size: side length of the square patches.
        mask_ratio: target fraction of masked patches.
        complexity_based: enable the complexity adjustment for ratios != 0.5.

    Returns:
        ``(masked_image, vector, position)``: the masked image with the batch
        dimension squeezed, the flattened visible patches, and the row-major
        indices of those patches.
    """
    B, C, H, W = images.shape
    masked_images = images.clone()

    if mask_ratio != 0.5:
        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # Per-patch complexity (standard deviation of the gray patch).
            complexity_map = patch_std(gray, patch_size)

            # Start from a chessboard over the patch grid; True means masked.
            complexity_height, complexity_width = complexity_map.shape
            mask = np.zeros((complexity_height, complexity_width), dtype=bool)
            mask[::2, 1::2] = 1
            mask[1::2, ::2] = 1

            if complexity_based:
                if mask_ratio > 0.5:
                    # Fraction of the visible patches that must also be masked.
                    additional_masking_ratio = (mask_ratio - 0.5) / 0.5
                    # Masking everything *below* the q-quantile masks a
                    # fraction q, so the level is the ratio itself.
                    complexity_threshold = np.quantile(complexity_map[~mask], additional_masking_ratio)
                    additional_mask = complexity_map < complexity_threshold
                    mask[~mask] = additional_mask[~mask]
                else:
                    # Fraction of the masked patches that must be revealed.
                    unmasking_ratio = (0.5 - mask_ratio) / 0.5
                    # Revealing everything *above* the q-quantile reveals a
                    # fraction 1 - q, so the level is 1 - ratio.
                    complexity_threshold = np.quantile(complexity_map[mask], 1 - unmasking_ratio)
                    unmask = complexity_map > complexity_threshold
                    mask[mask] = ~unmask[mask]

            # Expand the patch mask to pixel resolution and zero those pixels.
            pixel_mask = np.repeat(np.repeat(mask, patch_size, axis=0), patch_size, axis=1)
            image[:pixel_mask.shape[0], :pixel_mask.shape[1]][pixel_mask] = 0

            # Convert the image back to a float CHW tensor.
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    else:
        # Chessboard over the patch grid: 1 = keep, 0 = mask.
        pattern = np.tile(np.array([[1, 0] * (W // (2 * patch_size)), [0, 1] * (W // (2 * patch_size))]),
                          (H // (2 * patch_size), 1))
        # The pixel-level mask is identical for every image, so build it once.
        mask = np.repeat(np.repeat(pattern, patch_size, axis=0), patch_size, axis=1)

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)

            image[mask == 0] = 0  # apply the chessboard masking

            # Convert the image back to a float CHW tensor.
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    squeezed_masked_images = masked_images.squeeze(0)

    # Use this call's patch size rather than the global params['PS'] so the
    # flattening always matches the masking performed above.
    vector, position = masked2vector(masked_images, patch_size)

    return squeezed_masked_images, vector, position   # e.g. torch.Size([3, 32, 32]) for one CIFAR-10 image

def preprocess_and_save_dataset(dataset, root_dir, patch_size, mask_ratio, expected_len=None):
    """Mask every image of ``dataset`` and save the usable samples to disk.

    Each image is masked with ``mask_patches_chessboard`` (complexity-based)
    and stored as ``data_{i}.pt`` inside ``root_dir``. Only samples whose
    flattened visible-patch vector has exactly the expected length are
    written, so every saved vector has the same size.

    Args:
        dataset: iterable of ``(image, label)`` pairs with a length; images
            are (C, H, W) tensors.
        root_dir: output directory, created if missing.
        patch_size: side length of the square patches.
        mask_ratio: fraction of patches to mask.
        expected_len: required vector length. When ``None`` it is derived
            per image as ``round(num_patches * (1 - mask_ratio)) * 3 *
            patch_size**2`` (2304 for 32x32 images, patch size 2 and mask
            ratio 0.25), instead of a constant that fits only one ratio.
    """
    os.makedirs(root_dir, exist_ok=True)
    for i, (images, _) in tqdm(enumerate(dataset), total=len(dataset)):

        masked_images, vector, position = mask_patches_chessboard(
            images.unsqueeze(0), patch_size, mask_ratio, complexity_based=True)

        if expected_len is None:
            height, width = images.shape[-2], images.shape[-1]
            num_patches = (height // patch_size) * (width // patch_size)
            visible_patches = round(num_patches * (1 - mask_ratio))
            # Each visible patch contributes its R, G and B values.
            target_len = visible_patches * 3 * patch_size ** 2
        else:
            target_len = expected_len

        # Skip samples with a different number of visible patches (for
        # example when an unmasked patch happens to start with a zero pixel).
        if len(vector) != target_len:
            continue

        torch.save({
            'masked_images': masked_images.squeeze(0),
            'vector' : vector,
            'position' : position,
            'original_images': images
        }, os.path.join(root_dir, f'data_{i}.pt'))

class PreprocessedCIFAR10Dataset(Dataset):
    """Dataset over the ``.pt`` samples stored in a directory.

    Every file is expected to hold a dict with the keys ``masked_images``,
    ``vector``, ``position`` and ``original_images``.
    """

    def __init__(self, root_dir, transform=None):
        """Collect the sample files found directly inside ``root_dir``.

        Args:
            root_dir: directory containing the ``.pt`` files.
            transform: optional callable applied to both the masked and the
                original image of every sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.file_paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.endswith('.pt')
        )

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(masked_images, vector, position, original_images)`` where
            ``vector`` and ``position`` are always tensors.
        """
        data = torch.load(self.file_paths[idx])

        masked_images = data['masked_images']
        vector = data['vector']
        position = data['position']
        original_images = data['original_images']

        if not isinstance(vector, torch.Tensor):
            vector = torch.tensor(vector, dtype=torch.float32)  # convert to a float32 tensor
        if not isinstance(position, torch.Tensor):
            position = torch.tensor(position, dtype=torch.int)  # convert to an int32 tensor

        if self.transform:
            masked_images = self.transform(masked_images)
            original_images = self.transform(original_images)

        return masked_images, vector, position, original_images

import torch
import torch.nn as nn

import torch
import torch.nn as nn

import torch
import torch.nn as nn

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """1-D convolutional encoder: three Conv1d + ReLU stages.

    Channels go 1 -> 16 -> 32 -> 1; the second and third convolutions use
    stride 2, so the sequence length is reduced by about a factor of four.
    """

    def __init__(self):
        super().__init__()
        # Stage 1 keeps the length (stride 1).
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()

        # Stages 2 and 3 each halve the length (stride 2).
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv1d(32, 1, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` of shape (batch, 1, length) through the three stages."""
        stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
        )
        for conv, activation in stages:
            x = activation(conv(x))
        return x

class Decoder(nn.Module):
    """1-D transposed-convolution decoder.

    Channels go 1 -> 32 -> 16 -> 1. The first two layers use stride 2 with
    ``output_padding=1`` and each double the sequence length; the last layer
    keeps the length and is followed by a sigmoid, so outputs lie in (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Two upsampling stages (stride 2), each followed by ReLU.
        self.tconv1 = nn.ConvTranspose1d(1, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.relu1 = nn.ReLU()

        self.tconv2 = nn.ConvTranspose1d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.relu2 = nn.ReLU()

        # Final projection back to one channel, squashed by a sigmoid.
        self.tconv3 = nn.ConvTranspose1d(16, 1, kernel_size=3, stride=1, padding=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Decode ``x`` of shape (batch, 1, length) to (batch, 1, 4 * length)."""
        hidden = self.relu1(self.tconv1(x))
        hidden = self.relu2(self.tconv2(hidden))
        return self.sig(self.tconv3(hidden))



class Autoencoder(nn.Module):
    """1-D convolutional autoencoder with an optional AWGN channel helper.

    Args:
        channels: accepted for interface compatibility; not used by this class.
        latent_dim: latent size used to scale the noise power in ``AWGN``.
    """

    def __init__(self, channels, latent_dim):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder()
        self.decoder = Decoder()

    def AWGN(self, input, SNRdB):
        """Normalise ``input`` along dim 1 and add white Gaussian noise.

        The noise standard deviation is ``1 / sqrt(latent_dim * SNR)`` with
        ``SNR = 10 ** (SNRdB / 10)``.

        Args:
            input: tensor to send through the simulated channel.
            SNRdB: signal-to-noise ratio in decibels.

        Returns:
            The L2-normalised input plus noise, on the same device and with
            the same dtype as ``input``.
        """
        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        std = 1 / math.sqrt(self.latent_dim * SNR)
        # Sample the noise directly on the input's device/dtype instead of
        # relying on a module-level device global.
        noise = torch.randn_like(normalized_tensor) * std

        return normalized_tensor + noise

    def forward(self, x, SNRdB):
        """Encode and decode a batch of flat vectors.

        Args:
            x: tensor of shape (batch, length); a channel dimension is added
                before the convolutional encoder.
            SNRdB: currently unused, because the AWGN channel between encoder
                and decoder is disabled.

        Returns:
            The reconstruction with the channel dimension squeezed out again.
        """
        # Follow the device the model's parameters actually live on.
        model_device = next(self.parameters()).device
        x = x.unsqueeze(1).to(model_device)
        encoded = self.encoder(x)

        # NOTE(review): the channel step (self.AWGN(encoded, SNRdB)) is
        # disabled here. The encoder output has a single channel, so AWGN's
        # dim=1 normalisation would act on a size-1 axis - confirm the
        # intended axis before re-enabling it.
        decoded = self.decoder(encoded).squeeze(1)
        return decoded


def _psnr_from_mse(mse):
    """Return the PSNR in dB (peak value 1.0) for ``mse``, rounded to 3 places."""
    if mse <= 0:
        # A perfect reconstruction has unbounded PSNR; avoid dividing by zero.
        return float('inf')
    return round(10 * math.log10(1.0 / mse), 3)


def train(latent_dim, mask_ratio, trainloader, testloader):
    """Train one autoencoder per configured SNR and keep the best checkpoint.

    For every value in ``params['SNR']`` a fresh ``Autoencoder`` is trained
    with SGD (momentum 0.9) and an MSE loss between each input vector and its
    reconstruction. After each epoch the model is evaluated on
    ``testloader``; training stops early after 20 consecutive epochs without
    a lower test loss. Whenever the test PSNR improves, the whole model is
    saved under ``trained_model/`` and the previously saved best file is
    removed.

    Args:
        latent_dim: latent dimension passed to ``Autoencoder``.
        mask_ratio: mask ratio of the data; only used in checkpoint names.
        trainloader: loader yielding ``(masked_images, vector, position,
            original_images)`` batches for training.
        testloader: loader with the same batch layout for evaluation.
    """
    for snr_db in params['SNR']:

        model = Autoencoder(channels=1, latent_dim=latent_dim).to(device)
        print("Model size : {}".format(count_parameters(model)))

        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=params['LR'], momentum=0.9)

        min_test_cost = float('inf')
        epochs_no_improve = 0
        n_epochs_stop = 20

        print("+++++ SNR = {} Training Start! +++++\t".format(snr_db))

        # Start below any reachable PSNR so the first evaluated model is
        # always checkpointed, even if its PSNR is zero or negative.
        max_psnr = float('-inf')
        previous_best_model_path = None

        for epoch in range(params['EP']):
            # ========================================== Train ==========================================
            train_loss = 0.0

            model.train()
            timetemp = time.time()

            for _, vector, _, _ in trainloader:
                vector = vector.to(device)

                optimizer.zero_grad()
                recon_vector = model(vector, SNRdB=snr_db)

                loss = criterion(vector, recon_vector)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            train_cost = train_loss / len(trainloader)
            tr_psnr = _psnr_from_mse(train_cost)

            # ========================================== Test ==========================================
            test_loss = 0.0

            model.eval()
            with torch.no_grad():
                for _, vector, _, _ in testloader:
                    vector = vector.to(device)

                    recon_vector = model(vector, SNRdB=snr_db)
                    test_loss += criterion(vector, recon_vector).item()

            test_cost = test_loss / len(testloader)
            test_psnr = _psnr_from_mse(test_cost)

            # Early stopping on the test loss.
            if test_cost < min_test_cost:
                min_test_cost = test_cost
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == n_epochs_stop:
                print("Early stopping!")
                break

            training_time = time.time() - timetemp

            print(
                "[{:>3}-Epoch({:>5}sec.)]  PSNR(Train / Val) : {:>6.4f} / {:>6.4f}        Loss(Train / Val) : {:>5.5f} / {:>5.5f}".format(
                    epoch + 1, round(training_time, 2), tr_psnr, test_psnr, train_cost,  test_cost))

            if test_psnr > max_psnr:

                save_folder = 'trained_model'
                os.makedirs(save_folder, exist_ok=True)

                previous_psnr = max_psnr
                max_psnr = test_psnr

                if previous_best_model_path is not None:
                    print(f"Performance update!! {previous_psnr} to {max_psnr}")

                save_path = os.path.join(save_folder, f"AE_1D(MR={mask_ratio}_SNR={snr_db}_PSNR={max_psnr}).pt")
                torch.save(model, save_path)
                print(f"Saved new best model at {save_path}")

                # Delete the old checkpoint only after the new one is safely
                # on disk, so a failed save never leaves us without a model.
                if previous_best_model_path is not None and previous_best_model_path != save_path:
                    os.remove(previous_best_model_path)

                previous_best_model_path = save_path



if __name__ == '__main__':
    # For every configured mask ratio: build the preprocessed CIFAR-10
    # train/test folders if they are missing, then train on them.

    patch_size = params['PS']

    for mr_i, mask_ratio in enumerate(params['MR']):

        Processed_train_path = f"ProcessedTrain(PS={patch_size}_MR={mask_ratio})"
        Processed_test_path = f"ProcessedTest(PS={patch_size}_MR={mask_ratio})"

        # Preprocess a split only when its output folder does not exist yet.
        for split_dir, is_train in ((Processed_train_path, True), (Processed_test_path, False)):
            if os.path.exists(split_dir):
                continue
            to_tensor = transforms.Compose([transforms.ToTensor()])
            cifar10 = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=to_tensor)
            preprocess_and_save_dataset(cifar10, split_dir, patch_size=patch_size, mask_ratio=mask_ratio)

        traindataset = PreprocessedCIFAR10Dataset(root_dir=Processed_train_path)
        testdataset = PreprocessedCIFAR10Dataset(root_dir=Processed_test_path)

        loader_options = dict(batch_size=params['BS'], shuffle=True, num_workers=4)
        trainloader = DataLoader(traindataset, **loader_options)
        testloader = DataLoader(testdataset, **loader_options)

        # The latent dimension is looked up at the same index as the mask ratio.
        train(params['DIM'][mr_i], mask_ratio, trainloader, testloader)
# Hyper-parameters read by the preprocessing and training code.
params = {
    'BS': 32,                   # batch size passed to the DataLoaders
    'LR': 0.01,                 # learning rate of the SGD optimizer
    'EP': 1000,                 # maximum number of training epochs
    'SNR': [100],               # SNR values in dB; one model is trained per entry
    'DIM': [2304, 1536, 768],   # latent_dim per mask ratio, indexed like 'MR'
    'MR' : [0.25, 0.5, 0.75],   # mask ratios to preprocess and train on
    'PS' : 2,                   # patch size in pixels
    'IS' : (3, 32, 32)          # presumably the image shape (C, H, W); not read by the visible code - TODO confirm
}

'MAE' 카테고리의 다른 글

CBS (odd, even) masking  (0) 2024.04.18
line graph  (0) 2024.02.21
bar graph  (0) 2024.02.21
Comments