UOMOP

cifar10 HPF filter size에 따른 결과 확인 본문

Wireless Comm./Python

cifar10 HPF filter size에 따른 결과 확인

Happy PinGu 2023. 6. 8. 19:28
import cv2
import math
import time
import torch
import random
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
from tqdm import trange
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

# Experiment configuration: training hyper-parameters, the channel SNRs (dB)
# to sweep, autoencoder dimensions, and the HPF cut-off radii to test.
args = dict(
    BATCH_SIZE=50,
    LEARNING_RATE=0.001,
    NUM_EPOCH=50,
    SNR_dB=[1, 10, 20],
    latent_dim=10,
    input_dim=32 * 32,
    filter_size=[1],
)

# Convert each CIFAR-10 RGB image to a single-channel grayscale tensor.
transf = tr.Compose([tr.ToTensor(), tr.Grayscale(num_output_channels=1)])

# Download (if needed) and load both CIFAR-10 splits with the transform applied.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)

# Shuffle only the training split; evaluation order stays fixed.
trainloader = DataLoader(trainset, batch_size=args['BATCH_SIZE'], shuffle=True)
testloader = DataLoader(testset, batch_size=args['BATCH_SIZE'], shuffle=False)


def cal_D(c_row, c_col, r, c):
    """Return the Euclidean distance between (c_row, c_col) and (r, c)."""
    # math.hypot is clearer and numerically safer than sqrt of a sum of squares.
    return math.hypot(c_row - r, c_col - c)

def filter_radius(fshift, rad, low=True):
    """Zero out frequency components of a centred 2-D spectrum by radius.

    Parameters
    ----------
    fshift : 2-D array — fft2 output after fftshift (DC at the centre).
    rad : cut-off radius in pixels, measured from the spectrum centre.
    low : if True keep only components within ``rad`` (low-pass);
        otherwise keep only components beyond ``rad`` (high-pass).

    Returns a filtered copy; ``fshift`` itself is not modified.
    """
    rows, cols = fshift.shape
    c_row, c_col = int(rows / 2), int(cols / 2)

    # Distance of every pixel from the spectrum centre, computed in one
    # vectorised pass instead of the original O(rows*cols) Python double loop.
    r = np.arange(rows)[:, None]
    c = np.arange(cols)[None, :]
    dist = np.sqrt((c_row - r) ** 2 + (c_col - c) ** 2)

    filter_fshift = fshift.copy()
    # Pixels exactly at distance == rad are kept by both modes, matching the
    # original strict inequalities.
    if low:
        filter_fshift[dist > rad] = 0
    else:
        filter_fshift[dist < rad] = 0

    return filter_fshift
    
    
    
def HPF_tensor(x, filter_size, plot=False):
    """Apply a circular high-pass filter to every image in a batch.

    Each 32x32 grayscale image is FFT'd, frequencies within ``filter_size``
    pixels of the spectrum centre are zeroed (see filter_radius), and the
    magnitude of the inverse FFT is taken as the filtered image.

    Parameters
    ----------
    x : tensor reshapeable to (N, 1, 32, 32).
    filter_size : high-pass cut-off radius in pixels.
    plot : when True, draw each original/filtered pair with matplotlib.
        Off by default: the original code created a new, never-shown figure
        on every call, leaking memory inside the training loop.

    Returns a float32 tensor of shape (N, 1, 32, 32).
    """
    # Infer the batch size with -1 instead of hard-coding args['BATCH_SIZE'],
    # so a short final batch from a DataLoader also works.
    x = x.view(-1, 1, 32, 32)

    if plot:
        plt.figure()

    save_list = []
    for i in range(x.shape[0]):
        img = x[i, 0].detach().numpy()

        # FFT -> shift DC to centre -> zero the low frequencies -> inverse.
        fshift = np.fft.fftshift(np.fft.fft2(img))
        high_fshift = filter_radius(fshift, rad=filter_size, low=False)
        high_img = np.abs(np.fft.ifft2(np.fft.ifftshift(high_fshift)))
        save_list.append(high_img)

        if plot:
            plt.subplot(1, 2, 1)
            plt.imshow(img, cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(high_img, cmap='gray')

    save_arr = np.array(save_list).reshape(x.shape[0], 1, 32, 32)
    return torch.tensor(save_arr, dtype=torch.float)
    
    
class Encoder(nn.Module):
    """Maps a flattened 32x32 image to a latent vector of args['latent_dim'] symbols."""

    def __init__(self):
        super().__init__()
        # Build the MLP layer-by-layer: input_dim -> 512 -> 256 -> latent_dim,
        # with a ReLU after every Linear (including the last).
        widths = [args['input_dim'], 512, 256, args['latent_dim']]
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a (batch, input_dim) tensor into (batch, latent_dim)."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Reconstructs a flattened image from a latent vector; outputs lie in [0, 1]."""

    def __init__(self):
        super().__init__()
        # Mirror of the encoder; Sigmoid keeps pixel values in [0, 1].
        stages = (
            nn.Linear(args['latent_dim'], 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, args['input_dim']),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a (batch, latent_dim) tensor back to (batch, input_dim)."""
        return self.decoder(x)


class Autoencoder(nn.Module) :
    """Encoder -> AWGN channel -> decoder, modelling transmission over a noisy link."""

    def __init__(
        self,
        encoder_class : object = Encoder,
        decoder_class : object = Decoder
        ) :

        super(Autoencoder, self).__init__()

        # Fresh encoder/decoder networks are instantiated per model.
        self.encoder = encoder_class()
        self.decoder = decoder_class()



    def add_AWGN2tensor(self, input, SNRdB):
        """Add white Gaussian noise to `input` at the given SNR (in dB)."""

        # NOTE(review): `normalized_tensor` is computed but only its size is
        # used below — the noise is added to the *un-normalized* `input`,
        # while std is derived assuming unit-power symbols. This looks like a
        # bug: `return normalized_tensor + noise` was probably intended.
        # Left unchanged to preserve the trained behavior — confirm with author.
        normalized_tensor = f.normalize(input, dim=1)

        # std chosen so the total noise power over the K latent symbols
        # matches the requested SNR: SNR = 1 / (K * std^2).
        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        noise = torch.normal(0, std, size=normalized_tensor.size())

        return input + noise



    def forward(self, x, SNRdB) :
        """Encode x, corrupt the latent vector with AWGN at SNRdB, then decode."""
        encoded = self.encoder(x)
        encoded = self.add_AWGN2tensor(encoded, SNRdB)
        decoded = self.decoder(encoded)

        return decoded
        
        
        
# Sweep every (HPF radius, channel SNR) combination: train a fresh autoencoder
# on high-pass-filtered CIFAR-10 images and save its weights to disk.
for j in range(len(args['filter_size'])):

    print("=============== Filter_size = {} ===============".format(args['filter_size'][j]))

    for i in range( len(args['SNR_dB']) ):

        # New model/optimizer per SNR so runs do not contaminate each other.
        model = Autoencoder()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])

        print("+++++ SNR = {} Training Start! +++++".format(args['SNR_dB'][i]))

        for epoch in range(args['NUM_EPOCH']) :

            running_loss = 0.0

            for data in trainloader :

                inputs = data[0]
                # Both the model input and the reconstruction target are the
                # high-pass-filtered image, not the original.
                inputs_HPF = HPF_tensor(inputs, args['filter_size'][j])

                optimizer.zero_grad()

                # BUG FIX: the original passed the loop index `i` (0/1/2) as
                # the SNR-in-dB argument, so training never used the
                # configured SNRs (1/10/20 dB) that the logs and checkpoint
                # filenames claim. Pass the actual SNR value.
                outputs = model(inputs_HPF.view(-1, args['input_dim']), args['SNR_dB'][i])
                outputs = outputs.view(-1, 1, 32, 32)

                loss = criterion(inputs_HPF, outputs)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            cost = running_loss/len(trainloader)
            print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))

        print()
        # Checkpoint name encodes the (SNR, filter size) pair of this run.
        PATH = "./"
        torch.save(model.state_dict(), PATH + "model_AWGN" + "_SNR=" + str(args['SNR_dB'][i]) + "_filsize=" + str(args['filter_size'][j]) +".pth")

    print("==================================================\n\n")

cut-off freq가 낮을수록 원래 이미지에 가까움
Comments