UOMOP
AWGN_cifar10_20230804
######################## Library ########################
import cv2
import math
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torch.utils.data import DataLoader
######################## GPU Check ########################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
######################## Control Tower ########################
args = {
'BATCH_SIZE' : 64,
'LEARNING_RATE' : 0.001,
'NUM_EPOCH' : 500,
'SNRdB_list' : [1, 10, 20, 30],
'latent_dim' : 512,
'input_dim' : 32 * 32,
'filter_type' : 'scharr',
'filter_dir' : 'x',
'rl_pat' : 10
}
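# Notes on the settings above (added commentary):
#  - 'latent_dim' is used as K in the AWGN noise-power calculation further below. With the
#    Encoder defined later, the flattened feature passed through the channel is 19 * 5 * 5 = 475
#    values, so K = 512 is an approximation rather than the exact latent dimension.
#  - 'input_dim', 'filter_type', and 'filter_dir' are not referenced anywhere in this script;
#    the filter helpers below are defined but never called here.
#  - 'rl_pat' is the patience (in epochs) of the ReduceLROnPlateau scheduler.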
######################## Function ########################
def sobel_filter(img, x_or_y) :
if x_or_y == 'x' :
return cv2.Sobel(img, -1, 1, 0)
else :
return cv2.Sobel(img, -1, 0, 1)
def scharr_filter(img, x_or_y) :
if x_or_y == 'x' :
return cv2.Scharr(img, -1, 1, 0)
else :
return cv2.Scharr(img, -1, 0, 1)
def average_filter(img, kernel_size) :
return cv2.blur(img, (kernel_size, kernel_size))
def gaussian_filter(img, kernel_size) :
    # cv2.GaussianBlur takes a sigma argument: the standard deviation of the Gaussian kernel in
    # the X and Y directions (passing 0 lets OpenCV derive it from the kernel size).
return cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
def median_filter(img, kernel_size) :
return cv2.medianBlur(img, kernel_size)
def bilateral_filter(img, kernel_size) :
return cv2.bilateralFilter(img, kernel_size, 75, 75)
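# The filter helpers above are never called in this script ('filter_type' and 'filter_dir' in
# args are unused). A minimal sketch of how they might be dispatched, assuming an 8-bit HxWxC
# numpy image; apply_filter is a hypothetical helper, not part of the original pipeline:
def apply_filter(img):
    if args['filter_type'] == 'sobel':
        return sobel_filter(img, args['filter_dir'])
    elif args['filter_type'] == 'scharr':
        return scharr_filter(img, args['filter_dir'])
    elif args['filter_type'] == 'gaussian':
        return gaussian_filter(img, 3)   # kernel size 3 chosen arbitrarily for illustration
    return img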
def model_train(SNRdB) :
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=args['rl_pat'], threshold=1e-3, verbose = True)
print("\n\n+++++ SNR = {} Training Start! +++++\t".format(SNRdB))
for epoch in range(args['NUM_EPOCH']) :
train_MSE = 0.0
val_MSE = 0.0
############ Train ############
for train_data in trainloader :
inputs = train_data[0].to(device)
optimizer.zero_grad()
outputs = model(inputs, SNRdB = SNRdB)
loss = criterion(inputs, outputs)
loss.backward()
optimizer.step()
train_MSE += loss.item()
train_MSE_average = train_MSE / len(trainloader)
        ############ Validation ############
        with torch.no_grad():
            for val_data in testloader:
                inputs = val_data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                loss = criterion(inputs, outputs)
                val_MSE += loss.item()
        val_MSE_average = val_MSE / len(testloader)
scheduler.step(val_MSE_average)
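        # PSNR = 10 * log10(MAX^2 / MSE); the pixel values lie in [0, 1], so MAX = 1 and the
        # expression reduces to 10 * log10(1 / MSE), which is what is printed below.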
print("[{} Epoch] train_PSNR : {}\tvalidation_PSNR : {}"
.format(epoch + 1, round(10*math.log10(1/train_MSE_average), 3), round(10*math.log10(1/val_MSE_average), 3)))
torch.save(model.state_dict(), "./AWGN_" + "dim=" + str(args['latent_dim']) + "_SNRdB=" + str(SNRdB) +".pth")
######################## DataSet ########################
transf = tr.Compose([tr.ToTensor()])
# With only ToTensor(), the CIFAR-10 pixels are scaled to [0, 1]; no further normalization is
# applied here, which matches the Sigmoid output of the decoder and the MAX = 1 used for PSNR.
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)
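# Note (added commentary): the same testloader serves both as the validation set during training
# and as the evaluation set in psnr_test below; shuffle=True on the test set only changes batch order.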
######################## Encoder ########################
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
c_hid = 32
self.encoder = nn.Sequential(
nn.Conv2d(3, 16, kernel_size = 5, padding = 'valid', stride = 2),
nn.PReLU(),
nn.Conv2d(16, 32, kernel_size = 5, padding = 'valid', stride = 2),
nn.PReLU(),
nn.Conv2d(32, 32, kernel_size = 5, padding = 'same', stride = 1),
nn.PReLU(),
nn.Conv2d(32, 19, kernel_size = 5, padding = 'same', stride = 1),
nn.PReLU(),
nn.Conv2d(19, 19, kernel_size = 5, padding = 'same', stride = 1),
nn.PReLU()
)
def forward(self, x):
encoded = self.encoder(x)
return encoded
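# Shape trace for a 32 x 32 CIFAR-10 input (derived from the layers above):
# 3 x 32 x 32 -> 16 x 14 x 14 -> 32 x 5 x 5 -> 32 x 5 x 5 -> 19 x 5 x 5 -> 19 x 5 x 5,
# i.e. the flattened latent sent through the channel has 19 * 5 * 5 = 475 values.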
######################## Decoder ########################
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.decoder = nn.Sequential(
nn.ConvTranspose2d(19, 32, kernel_size = 5, padding = 0, stride = 1),
nn.PReLU(),
nn.ConvTranspose2d(32, 32, kernel_size = 5, padding = 0, stride = 1),
nn.PReLU(),
nn.ConvTranspose2d(32, 32, kernel_size = 5, padding = 0, stride = 1),
nn.PReLU(),
nn.ConvTranspose2d(32, 16, kernel_size = 5, padding = 0, stride = 2),
nn.PReLU(),
nn.ConvTranspose2d(16, 3, kernel_size = 5, padding = 0, stride = 2),
nn.Sigmoid()
)
def forward(self, x):
decoded = self.decoder(x)
#print("decoded shape : {}".format(decoded.size()))
decoded_interpolated = f.interpolate(decoded, size=(32, 32), mode='bilinear', align_corners=False)
#print("decoded_interpolated shape : {}".format(decoded_interpolated.size()))
return decoded_interpolated
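# With a 19 x 5 x 5 input, the transposed convolutions above upsample the spatial size
# 5 -> 9 -> 13 -> 17 -> 37 -> 77, so the final bilinear interpolation resamples the
# 77 x 77 reconstruction back to the original 32 x 32 resolution.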
######################## AutoEncoder ########################
class Autoencoder(nn.Module):
def __init__(
self,
encoder_class: object = Encoder,
decoder_class: object = Decoder
):
super(Autoencoder, self).__init__()
self.encoder = encoder_class()
self.decoder = decoder_class()
def AWGN(self, input, SNRdB):
normalized_tensor = f.normalize(input, dim=1)
SNR = 10.0 ** (SNRdB / 10.0)
K = args['latent_dim']
std = 1 / math.sqrt(K * SNR)
n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
return normalized_tensor + n
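    # Why std = 1 / sqrt(K * SNR): f.normalize gives each flattened latent unit L2 norm, so the
    # average signal power per dimension is 1 / K. Drawing noise with variance 1 / (K * SNR)
    # then gives a per-dimension signal-to-noise ratio of (1/K) / (1/(K*SNR)) = SNR.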
def forward(self, x, SNRdB):
#print("input size : {}".format(x.size()))
encoded = self.encoder(x)
#print("encoded size : {}".format(encoded.size()))
before_shape = encoded.size()
encoded_flatten = torch.flatten(encoded, 1)
#print("encoded_flatten size : {}".format(encoded_flatten.size()))
encoded_AWGN = self.AWGN(encoded_flatten, SNRdB)
#print("encoded_AWGN size : {}".format(encoded_AWGN.size()))
reshaped_encoded = encoded_AWGN.view(before_shape)
#print("reshaped_encoded size : {}".format(reshaped_encoded.size()))
decoded = self.decoder(reshaped_encoded)
#print("decoded size : {}".format(decoded.size()))
return decoded
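# Optional sanity check (commented out): push one batch through an untrained model and confirm
# that the reconstruction has the same shape as the input.
# sample = next(iter(trainloader))[0].to(device)
# print(Autoencoder().to(device)(sample, SNRdB=10).shape)   # expected: torch.Size([64, 3, 32, 32])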
######################## Train ########################
for i in range( len(args['SNRdB_list'])) :
model_train(args['SNRdB_list'][i])
def psnr_test() :
psnr_list = []
criterion = nn.MSELoss()
for i in range(len(args['SNRdB_list'])) :
SNRdB = args['SNRdB_list'][i]
model_name = "AWGN_" + "dim=" + str(args['latent_dim']) + "_SNRdB=" + str(args['SNRdB_list'][i]) +".pth"
model = Autoencoder().to(device)
model.load_state_dict(torch.load(model_name))
test_MSE = 0.0
        with torch.no_grad():
            for test_data in testloader:
                inputs = test_data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                loss = criterion(inputs, outputs)
                test_MSE += loss.item()
test_MSE_average = test_MSE / len(testloader)
psnr_list.append(10*math.log10(1/test_MSE_average))
return psnr_list
psnr_list = psnr_test()
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
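# KMP_DUPLICATE_LIB_OK works around the "libiomp5 already initialized" OpenMP error that can
# occur when matplotlib and PyTorch are imported together (common on Anaconda/Windows setups).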
plt.plot(args['SNRdB_list'], psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.grid(True)
plt.legend()
plt.ylim([10, 40])
plt.show()
print(psnr_list)