UOMOP
CIFAR-10: HPF filter size에 따른 결과 확인
import cv2
import math
import time
import torch
import random
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
from tqdm import trange
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
# Experiment hyper-parameters shared by the data loaders, the model and the training loop.
args = dict(
    BATCH_SIZE=50,          # samples per mini-batch
    LEARNING_RATE=0.001,    # Adam learning rate
    NUM_EPOCH=50,           # training epochs per (filter size, SNR) combination
    SNR_dB=[1, 10, 20],     # channel SNR values to sweep, in dB
    latent_dim=10,          # width of the encoder output / channel input
    input_dim=32 * 32,      # flattened size of one grayscale 32x32 image
    filter_size=[1],        # high-pass filter radii to sweep
)
# Convert each PIL image to a tensor, then collapse RGB to a single grayscale channel.
transf = tr.Compose([
    tr.ToTensor(),
    tr.Grayscale(num_output_channels=1),
])

# CIFAR-10 train split first, then test split (downloaded into ./data if missing).
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transf)
    for is_train in (True, False)
)

# Training batches are shuffled each epoch; test batches keep a fixed order.
trainloader = DataLoader(trainset, batch_size=args['BATCH_SIZE'], shuffle=True)
testloader = DataLoader(testset, batch_size=args['BATCH_SIZE'], shuffle=False)
def cal_D(c_row, c_col, r, c) :
    """Return the Euclidean distance between the points (c_row, c_col) and (r, c)."""
    d_row = c_row - r
    d_col = c_col - c
    squared = d_row ** 2 + d_col ** 2
    return squared ** 0.5
def filter_radius(fshift, rad, low = True) :
    """Zero out entries of a centered 2-D spectrum by distance from its center.

    The center is (rows // 2, cols // 2), which is where ``np.fft.fftshift``
    places the DC component.

    Args:
        fshift: array whose last two axes are (rows, cols). Any leading axes are
            treated as a stack and get the same mask.
        rad: cut-off radius, measured in array cells.
        low: if True (low-pass), zero every entry whose distance is > rad.
            If False (high-pass), zero every entry whose distance is < rad.

    Returns:
        A filtered copy of ``fshift``. The input is not modified.
    """
    rows, cols = fshift.shape[-2:]
    c_row, c_col = rows // 2, cols // 2
    # Open grids broadcast to a (rows, cols) distance map without a Python loop.
    r_idx, c_idx = np.ogrid[:rows, :cols]
    dist = np.sqrt((c_row - r_idx) ** 2 + (c_col - c_idx) ** 2)
    mask = dist > rad if low else dist < rad

    filter_fshift = fshift.copy()
    filter_fshift[..., mask] = 0
    return filter_fshift
def HPF_tensor(x, filter_size, plot = False) :
    """High-pass filter a batch of 32x32 single-channel images in the 2-D frequency domain.

    For each image:
      1. take the FFT and shift the DC component to the center;
      2. zero every frequency bin whose distance from the center is < filter_size
         (the same rule as ``filter_radius(..., low=False)``);
      3. inverse-shift, inverse-FFT, and keep the magnitude.

    Args:
        x: tensor holding 32*32 values per image, e.g. (B, 1, 32, 32) or
            (B, 1024). Any batch size works, including a short last batch.
        filter_size: cut-off radius in frequency bins. 0 keeps everything;
            1 removes only the DC (mean) component.
        plot: if True, open a matplotlib figure showing the last input image
            next to its filtered version. This was previously always on;
            creating a figure per call leaked figures during training.

    Returns:
        float32 tensor of shape (B, 1, 32, 32) on the same device as ``x``.
    """
    # NumPy FFT needs a detached CPU array; this also works for GPU or grad tensors.
    imgs = x.detach().cpu().numpy().reshape(-1, 1, 32, 32)
    axes = (-2, -1)

    spectrum = np.fft.fftshift(np.fft.fft2(imgs, axes=axes), axes=axes)
    rows, cols = spectrum.shape[-2:]
    r_idx, c_idx = np.ogrid[:rows, :cols]
    dist = np.sqrt((rows // 2 - r_idx) ** 2 + (cols // 2 - c_idx) ** 2)
    spectrum[..., dist < filter_size] = 0

    high_img = np.abs(np.fft.ifft2(np.fft.ifftshift(spectrum, axes=axes), axes=axes))

    if plot and len(imgs):
        plt.figure()
        plt.subplot(1, 2, 1)
        plt.imshow(imgs[-1, 0], cmap = 'gray')
        plt.subplot(1, 2, 2)
        plt.imshow(high_img[-1, 0], cmap = 'gray')

    return torch.tensor(high_img, dtype = torch.float, device = x.device)
class Encoder(nn.Module) :
    """Fully connected encoder: input_dim -> 512 -> 256 -> latent_dim.

    Each Linear layer is followed by a ReLU, including the last one, so the
    latent vector is element-wise non-negative.
    """
    def __init__(self) :
        super().__init__()
        widths = [args['input_dim'], 512, 256, args['latent_dim']]
        stages = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(w_in, w_out))
            stages.append(nn.ReLU())
        self.encoder = nn.Sequential(*stages)

    def forward(self, x) :
        """Map flattened inputs of width input_dim to latent vectors."""
        latent = self.encoder(x)
        return latent
class Decoder(nn.Module) :
    """Fully connected decoder: latent_dim -> 256 -> 512 -> input_dim.

    The final Sigmoid keeps every reconstructed value in (0, 1).
    """
    def __init__(self) :
        super().__init__()
        first_hidden, second_hidden = 256, 512
        self.decoder = nn.Sequential(
            nn.Linear(args['latent_dim'], first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, args['input_dim']),
            nn.Sigmoid(),
        )

    def forward(self, x) :
        """Map latent vectors back to flattened outputs of width input_dim."""
        reconstruction = self.decoder(x)
        return reconstruction
class Autoencoder(nn.Module) :
    """Encoder -> additive white Gaussian noise (AWGN) channel -> decoder."""
    def __init__(
        self,
        encoder_class : object = Encoder,
        decoder_class : object = Decoder
    ) :
        super(Autoencoder, self).__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def add_AWGN2tensor(self, input, SNRdB):
        """Normalize each latent vector to unit L2 norm, then add Gaussian noise.

        With unit norm over K elements, the average per-element signal power
        is 1/K. Noise with variance 1/(K * SNR) therefore yields the requested
        SNR. The previous version computed the normalized tensor but added the
        noise to the raw ``input``, which made the SNR setting meaningless.

        Args:
            input: latent batch, normalized along dim 1.
            SNRdB: channel signal-to-noise ratio in dB.

        Returns:
            Noisy, power-normalized tensor with the same shape, dtype and
            device as ``input``.
        """
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        K = normalized_tensor.size(1)  # equals args['latent_dim'] for the default Encoder
        std = 1 / math.sqrt(K * SNR)
        # randn_like keeps the noise on the same device/dtype as the signal.
        noise = torch.randn_like(normalized_tensor) * std
        return normalized_tensor + noise

    def forward(self, x, SNRdB) :
        """Encode ``x``, pass it through the AWGN channel at ``SNRdB`` dB, then decode."""
        encoded = self.encoder(x)
        encoded = self.add_AWGN2tensor(encoded, SNRdB)
        decoded = self.decoder(encoded)
        return decoded
# Train one autoencoder per (HPF filter size, channel SNR) pair and save its weights.
for filter_size in args['filter_size']:
    print("=============== Filter_size = {} ===============".format(filter_size))
    for snr_db in args['SNR_dB']:
        model = Autoencoder()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])
        print("+++++ SNR = {} Training Start! +++++".format(snr_db))
        model.train()
        for epoch in range(args['NUM_EPOCH']) :
            running_loss = 0.0
            for inputs, _ in trainloader :
                # The model learns to reconstruct the high-pass-filtered image.
                inputs_HPF = HPF_tensor(inputs, filter_size)
                optimizer.zero_grad()
                # Pass the SNR value in dB. The original passed the loop index
                # (0, 1, 2) instead, so every run trained at the wrong SNR.
                outputs = model(inputs_HPF.view(-1, args['input_dim']), snr_db)
                outputs = outputs.view(-1, 1, 32, 32)
                loss = criterion(outputs, inputs_HPF)  # (prediction, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            cost = running_loss / len(trainloader)
            print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))
        print()
        PATH = "./"
        torch.save(model.state_dict(), PATH + "model_AWGN" + "_SNR=" + str(snr_db) + "_filsize=" + str(filter_size) + ".pth")
    print("==================================================\n\n")
cut-off frequency가 낮을수록 원래 이미지에 가까움 (HPF가 제거하는 저주파 성분이 줄어들기 때문)
'Wireless Comm. > Python' 카테고리의 다른 글
WSR maximization using WMMSE for MIMO-BC beamforming design using Python (0) | 2023.09.15 |
---|---|
***AutoEncoder cifar10(color) 1dB, 10dB, 20dB (0) | 2023.06.08 |
***Cifar10_AE(color)_20230608 (0) | 2023.06.08 |
Traditional AutoEncoder Cifar10(gray) 2023 06 07 (0) | 2023.06.07 |
send img with contour : 20230603 (0) | 2023.06.03 |
Comments