UOMOP
AWGN channel autoencoder (pi=512, si=0) — full source
import cv2
import math
import time
import torch
import random
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
def sobel_filter(img, x_or_y):
    """Apply a first-order Sobel derivative to *img*.

    'x' selects d/dx; any other value selects d/dy. Output depth
    matches the input (ddepth=-1).
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Sobel(img, -1, dx, dy)
def scharr_filter(img, x_or_y):
    """Apply a Scharr derivative (sharper 3x3 gradient than Sobel) to *img*.

    'x' selects d/dx; any other value selects d/dy. Output depth
    matches the input (ddepth=-1).
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Scharr(img, -1, dx, dy)
def average_filter(img, kernel_size):
    """Smooth *img* with a normalized box (mean) filter of the given size."""
    ksize = (kernel_size, kernel_size)
    return cv2.blur(img, ksize)
def gaussian_filter(img, kernel_size):
    """Smooth *img* with a Gaussian kernel of the given size.

    The sigma parameter of cv2.GaussianBlur is the standard deviation of
    the Gaussian kernel in X and Y; passing 0 lets OpenCV derive it from
    the kernel size.
    """
    ksize = (kernel_size, kernel_size)
    return cv2.GaussianBlur(img, ksize, 0)
def median_filter(img, kernel_size):
    """Apply a median filter with an aperture of *kernel_size* to *img*."""
    return cv2.medianBlur(img, kernel_size)
def bilateral_filter(img, kernel_size, sigma_color=75, sigma_space=75):
    """Edge-preserving bilateral filter.

    Generalized: the color/space sigmas were hard-coded to 75; they are now
    keyword parameters with the same defaults, so existing callers are
    unaffected.
    """
    return cv2.bilateralFilter(img, kernel_size, sigma_color, sigma_space)
# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Global hyper-parameters read by the model and training code below.
args = {
'BATCH_SIZE' : 32,
'LEARNING_RATE' : 0.001,
'NUM_EPOCH' : 200,
'SNRdB_list' : [1, 10, 20, 30],  # one model is trained per channel SNR (dB)
'pi_dim' : 512,  # latent width of the image branch (Encoder.encoder_pi output)
'si_dim' : 0,  # latent width of the side-info branch (Encoder.encoder_si; 0 here)
'input_dim' : 32 * 32,  # flattened grayscale high-pass-filtered image fed to encoder_si
'filter_type' : 'scharr',  # high-pass filter used by Encoder.HPF ('sobel' or 'scharr')
'filter_dir' : 'x',  # derivative direction for the high-pass filter
'es_pat' : 10,  # EarlyStopping patience (early stopping is currently commented out)
'rl_pat' : 10  # ReduceLROnPlateau patience
}
transf = tr.Compose([tr.ToTensor(), tr.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# The torchvision CIFAR-10 tensors originally lie in [0, 1];
# the Normalize above rescales them to [-1, 1].
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)
def spectrum(img):
    """Return the centered log-magnitude spectrum of a 2-D image.

    The local variable was renamed from `f`, which shadowed the module
    alias `torch.nn.functional as f` imported at the top of the file.
    """
    freq = np.fft.fft2(img)
    # Shift the zero-frequency component to the center of the spectrum.
    centered = np.fft.fftshift(freq)
    # +1 guards against log(0) at zero-magnitude frequencies.
    return np.log(np.abs(centered) + 1)
def psnr(SNRdB_list, model_name_without_SNRdB, testloader):
    """Evaluate the average test-set PSNR of each saved model.

    For every SNR in SNRdB_list, loads the checkpoint saved by model_train
    and returns a list of mean PSNR values (rounded to 3 decimals), in the
    same order as SNRdB_list.
    """
    psnr_list = []
    for SNRdB in SNRdB_list:
        model_name = model_name_without_SNRdB + "_pi=" + str(args['pi_dim']) + "_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) + ".pth"
        model_1 = Autoencoder().to(device)
        model_1.load_state_dict(torch.load(model_name))
        model_1.eval()  # inference mode (no dropout/batchnorm updates)
        psnr_culmi = 0.0
        sample_count = 0
        with torch.no_grad():  # no gradients needed for evaluation
            for data in testloader:
                inputs = data[0].to(device)
                outputs = model_1(inputs, SNRdB=SNRdB)
                # BUG FIX: the original iterated range(len(data)); data is the
                # (images, labels) pair, so only 2 samples per batch were scored
                # and the divisor was wrong. Score every sample in the batch.
                for j in range(inputs.size(0)):
                    # Data range is 2 because inputs are normalized to [-1, 1].
                    psnr_culmi += cv2.PSNR(inputs[j].cpu().numpy(), outputs[j].cpu().numpy(), 2)
                sample_count += inputs.size(0)
        psnr_list.append(round(psnr_culmi / sample_count, 3))
    return psnr_list
def model_train(SNRdB, learning_rate, epoch_num, trainloader):
    """Train an Autoencoder over an AWGN channel at the given SNR and save it.

    Trains for epoch_num epochs with Adam + MSE reconstruction loss,
    reports per-epoch train/validation PSNR, and writes the weights to
    ./AWGN_pi=<pi>_si=<si>_SNR=<SNRdB>.pth.

    NOTE(review): validation reads the module-level `testloader`, as in the
    original code — confirm that is intentional.
    """
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # PSNR is maximized, hence mode='max'.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args['rl_pat'], threshold=1e-3, verbose=True)
    print("\n\n+++++ SNR = {} Training Start! +++++\t".format(SNRdB))
    for epoch in range(epoch_num):
        train_psnr = 0.0
        val_psnr = 0.0
        train_count = 0
        val_count = 0
        ## Train dataset
        model.train()
        for data in trainloader:
            inputs = data[0].to(device)
            optimizer.zero_grad()
            outputs = model(inputs, SNRdB=SNRdB)
            loss = criterion(inputs, outputs)
            loss.backward()
            optimizer.step()
            # BUG FIX: the original iterated range(len(data)); data is the
            # (images, labels) pair, so only 2 samples per batch were scored
            # and the divisor was wrong. Score every sample in the batch.
            for j in range(inputs.size(0)):
                # Data range is 2 because inputs are normalized to [-1, 1].
                train_psnr += cv2.PSNR(inputs[j].detach().cpu().numpy(), outputs[j].detach().cpu().numpy(), 2)
            train_count += inputs.size(0)
        train_psnr_fin = round(train_psnr / train_count, 3)
        ## Validation dataset
        model.eval()
        with torch.no_grad():  # validation must not accumulate gradients
            for val_data in testloader:
                inputs = val_data[0].to(device)
                outputs = model(inputs, SNRdB=SNRdB)
                for j in range(inputs.size(0)):
                    val_psnr += cv2.PSNR(inputs[j].cpu().numpy(), outputs[j].cpu().numpy(), 2)
                val_count += inputs.size(0)
        val_psnr_fin = round(val_psnr / val_count, 3)
        # BUG FIX: the original stepped on the raw cumulative sum, which makes
        # the scheduler's 1e-3 threshold meaningless; use the epoch average.
        scheduler.step(val_psnr_fin)
        print("[{} Epoch] train_PSNR : {}\tvalidation_PSNR : {}".format(epoch + 1, train_psnr_fin, val_psnr_fin))
    torch.save(model.state_dict(), "./AWGN" + "_pi=" + str(args['pi_dim']) +"_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) +".pth")
class Encoder(nn.Module):
    """Two-branch encoder producing a primary and a side-information latent.

    The primary branch (encoder_pi) is a conv stack on the RGB image; the
    side branch (encoder_si) is an MLP on a high-pass-filtered grayscale
    version of the same image.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32
        # Primary branch: 3x32x32 image -> 4x4 feature map -> pi_dim vector.
        self.encoder_pi = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['pi_dim'])
        )
        # Side branch: flattened 32x32 HPF image -> si_dim vector.
        self.encoder_si = nn.Sequential(
            nn.Linear(args['input_dim'], 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, args['si_dim'])
        )

    def HPF(self, x, filter_type, filter_dir):
        """High-pass-filter each image in batch x.

        x is a (B, 3, 32, 32) tensor; returns a (B, 1, 32, 32) CPU tensor of
        per-image gradient maps.

        Raises ValueError for an unknown filter type or direction (the
        original merely printed an error and then crashed on the reshape
        below with an unrelated message).
        """
        if filter_type not in ('sobel', 'scharr'):
            raise ValueError("Type Error(filter type): {}".format(filter_type))
        if filter_dir not in ('x', 'y'):
            raise ValueError("Type Error(filter direction): {}".format(filter_dir))
        apply_filter = sobel_filter if filter_type == 'sobel' else scharr_filter
        save_list = []
        for i in range(x.shape[0]):
            # NOTE(review): tensors come from torchvision (RGB order) but
            # COLOR_BGR2GRAY assumes BGR — the channel weighting is swapped;
            # confirm whether this is intentional.
            gray_img = cv2.cvtColor(x[i].permute(1, 2, 0).detach().cpu().numpy(), cv2.COLOR_BGR2GRAY)
            save_list.append(apply_filter(gray_img, filter_dir))
        save_arr = np.array(save_list).reshape(x.shape[0], 1, 32, 32)
        return torch.Tensor(save_arr)

    def forward(self, x, filter_type, filter_dir):
        # BUG FIX: the original ignored the filter_type/filter_dir arguments
        # and read args[...] directly; honor the parameters instead (in-file
        # callers pass the same args values, so their behavior is unchanged).
        HPFed_x = self.HPF(x, filter_type, filter_dir).to(device)
        encoded_pi = self.encoder_pi(x)
        encoded_si = self.encoder_si(HPFed_x.view(-1, args['input_dim']).to(torch.float32))
        return encoded_pi, encoded_si
class EarlyStopping:
    """Stop training once validation PSNR stalls for `patience` checks.

    Each call reports a new validation PSNR; improvements checkpoint the
    model and reset the stall counter, otherwise the counter grows until
    `early_stop` flips to True.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive non-improving checks
        self.best_score = None    # best PSNR seen so far
        self.early_stop = False
        self.psnr_max = float(0)  # PSNR at the last saved checkpoint
        self.delta = delta        # minimum change that counts as improvement

    def __call__(self, psnr, model, SNRdB):
        score = psnr
        if self.best_score is None:
            # First observation always counts as an improvement.
            self.best_score = score
            self.save_checkpoint(psnr, model, SNRdB)
            return
        if score >= self.best_score + self.delta:
            # Improvement: record it, checkpoint, and reset the counter.
            self.best_score = score
            self.save_checkpoint(psnr, model, SNRdB)
            self.counter = 0
        else:
            self.counter += 1
            print("EarlyStopping counter : {} out of {}".format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, psnr, model, SNRdB):
        """Persist the model weights as the new best checkpoint."""
        torch.save(model.state_dict(), "./AWGN" + "_pi=" + str(args['pi_dim']) +"_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) +".pth")
        self.psnr_max = psnr
class Decoder(nn.Module):
    """Map a latent vector of size pi_dim + si_dim back to a 3x32x32 image."""

    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32
        # Project the concatenated latent back to a flattened 4x4 feature map.
        self.linear = nn.Sequential(
            nn.Linear(args['pi_dim'] + args['si_dim'], 2 * 16 * c_hid),
            nn.ReLU()
        )
        # Three transposed-conv stages upsample 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):
        hidden = self.linear(x)
        # Un-flatten to (batch, 2*c_hid, 4, 4) for the conv stack.
        feature_map = hidden.reshape(hidden.shape[0], -1, 4, 4)
        return self.decoder(feature_map)
class Autoencoder(nn.Module):
    """End-to-end pipeline: Encoder -> AWGN channel -> Decoder."""

    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        """Simulate an AWGN channel at the given SNR (in dB).

        The latent is L2-normalized per sample (power constraint), then
        Gaussian noise with std 1/sqrt(K*SNR) is added, where K is the
        total latent dimension pi_dim + si_dim.
        """
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['pi_dim'] + args['si_dim']
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
        return normalized_tensor + n

    def L2_Normalization(self, input):
        """Scale each row of `input` to unit L2 norm.

        BUG FIX: the original multiplied each row by its norm, which is not
        a normalization; divide instead (vectorized — no per-row Python
        loop / list round-trip). Unused by forward(), kept for callers.
        """
        norm_2 = torch.norm(input, p=2, dim=1, keepdim=True)
        return input / norm_2

    def forward(self, x, SNRdB):
        # Encode both branches, concatenate into the transmitted vector,
        # pass it through the noisy channel, then reconstruct.
        encoded_pi, encoded_si = self.encoder(x, args['filter_type'], args['filter_dir'])
        Tx_output = torch.cat([encoded_pi, encoded_si], dim=1)
        Rx_input = self.AWGN(Tx_output, SNRdB)
        decoded = self.decoder(Rx_input)
        return decoded
# ---- Training and evaluation driver ----
SNRdB_list = args['SNRdB_list']
learning_rate = args['LEARNING_RATE']
epoch_num = args['NUM_EPOCH']
model_name_without_SNRdB = 'AWGN'
# Train one model per channel SNR. (The original looped over
# range(len(...)) and indexed; iterate the list directly. A redundant
# `trainloader = trainloader` self-assignment was also removed.)
for SNRdB in SNRdB_list:
    model_train(SNRdB, learning_rate, epoch_num, trainloader)
import os
# Work around the duplicate-OpenMP-runtime abort seen when cv2/torch and
# matplotlib load conflicting libiomp copies.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Evaluate every saved checkpoint and plot PSNR vs channel SNR.
psnr_list = psnr(SNRdB_list, model_name_without_SNRdB, testloader)
plt.plot(SNRdB_list, psnr_list, linestyle='dashed', color='blue', label="AWGN")
plt.grid(True)
plt.legend()
plt.ylim([10, 40])
plt.show()
print(psnr_list)
Other posts in the 'Wireless Comm. > CISL' category:
- AWGN(pi=128, si=0) (0) | 2023.07.27
- AWGN(pi=256, si=0) (0) | 2023.07.26
- AWGN, latent_dim = 500 (0) | 2023.07.20
- Cifar10 AWGN [1dB, 10dB, 20dB] (0) | 2023.07.18
- Cifar10 Rayleigh [1dB 10dB 20dB] (0) | 2023.07.11
Comments