DeepJSCC main code
import cv2
import math
import time
import torch
import random
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
args = {
    'BATCH_SIZE'    : 64,
    'LEARNING_RATE' : 0.001,
    'EPOCH'         : 200,
    'SNRdB_list'    : [0, 10, 20, 30],
    'input_dim'     : 32 * 32,
    'rl_pat'        : 7
}
transf = tr.Compose([tr.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size = 5, padding = 'valid', stride = 2),   # 32x32 -> 14x14
            nn.PReLU(),
            nn.Conv2d(16, 32, kernel_size = 5, padding = 'valid', stride = 2),  # 14x14 -> 5x5
            nn.PReLU(),
            nn.Conv2d(32, 32, kernel_size = 5, padding = 'same', stride = 1),
            nn.PReLU(),
            nn.Conv2d(32, 19, kernel_size = 5, padding = 'same', stride = 1),
            nn.PReLU(),
            nn.Conv2d(19, 19, kernel_size = 5, padding = 'same', stride = 1),   # output: 19 x 5 x 5 = 475 values per image
            nn.PReLU()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(19, 32, kernel_size = 5, padding = 0, stride = 1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size = 5, padding = 0, stride = 1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size = 5, padding = 0, stride = 1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size = 5, padding = 0, stride = 2),
            nn.PReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size = 5, padding = 0, stride = 2),
            nn.Sigmoid()
        )
        # Since torch 1.10.0, nn.Conv2d accepts the strings 'same' / 'valid' for padding,
        # but nn.ConvTranspose2d does not, so its padding has to be given as an integer:
        # padding_height = [strides[1] * (in_height - 1) + kernel_size[0] - out_height] / 2
        # padding_width  = [strides[2] * (in_width  - 1) + kernel_size[1] - out_width ] / 2
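        # With padding = 0 (and no output_padding), nn.ConvTranspose2d gives
        # out = (in - 1) * stride + kernel_size, so the 5x5 code grows as
        # 5 -> 9 -> 13 -> 17 -> 37 -> 77 through the five layers above,
        # which is why forward() has to interpolate back down to 32x32.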
    def forward(self, x):
        decoded = self.decoder(x)
        #print("decoded shape : {}".format(decoded.size()))
        decoded_interpolated = f.interpolate(decoded, size=(32, 32), mode='bilinear', align_corners=False)
        #print("decoded_interpolated shape : {}".format(decoded_interpolated.size()))
        return decoded_interpolated
class Autoencoder(nn.Module):
    def __init__(
        self,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        # Each flattened code vector is normalized to unit L2 norm, so the average
        # power per symbol is 1/K with K = 19 * 5 * 5 = 475 channel uses per image.
        # For the given SNR, the noise variance is therefore 1 / (K * SNR).
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        K = 475
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
        return normalized_tensor + n

    def forward(self, x, SNRdB):
        #print("input : {}".format(x.size()))
        encoded = self.encoder(x)
        #print("encoded : {}".format(encoded.size()))
        shape = encoded.size()
        encoded_flatten = torch.flatten(encoded, start_dim=1)
        #print("encoded_flatten : {}".format(encoded_flatten.size()))
        encoded_AWGN = self.AWGN(encoded_flatten, SNRdB)
        encoded_AWGN_reshaped = encoded_AWGN.view(shape)
        decoded = self.decoder(encoded_AWGN_reshaped)
        #print("decoded size : {}".format(decoded.size()))
        return decoded
def model_train(trainloader, testloader):
    for i in range(len(args['SNRdB_list'])):
        model = Autoencoder().to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=args['rl_pat'],
                                                         threshold=1e-3, verbose=True)
        print("+++++ SNR = {} Training Start! +++++".format(args['SNRdB_list'][i]))
        for epoch in range(args['EPOCH']):
            #========================================== Train ==========================================
            model.train()
            train_loss = 0.0
            for data in trainloader:
                inputs = data[0].to(device)
                optimizer.zero_grad()
                outputs = model(inputs, SNRdB = args['SNRdB_list'][i])
                loss = criterion(inputs, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_cost = train_loss / len(trainloader)
            #========================================== Test ==========================================
            model.eval()
            test_loss = 0.0
            with torch.no_grad():
                for data in testloader:
                    inputs = data[0].to(device)
                    outputs = model(inputs, SNRdB = args['SNRdB_list'][i])
                    loss = criterion(inputs, outputs)
                    test_loss += loss.item()
            test_cost = test_loss / len(testloader)
            #========================================== print(Metric) ==========================================
            scheduler.step(test_cost)
            # PSNR in dB for images scaled to [0, 1]: 10 * log10(1 / MSE)
            #print("[{} Epoch] Train Loss : {} Val. Loss : {}".format(epoch + 1, round(train_cost, 6), round(test_cost, 6)))
            print("[{} Epoch] Train PSNR : {} Val. PSNR : {}".format(epoch + 1, round(10*math.log10(1/train_cost), 4), round(10*math.log10(1/test_cost), 4)))
        torch.save(model.state_dict(), "DeepJSCC_" + str(args['SNRdB_list'][i]) + "dB.pth")

#model_train(trainloader, testloader)
def DeepJSCC_test(testloader):
    psnr_list = []
    for i in range(len(args['SNRdB_list'])):
        criterion = nn.MSELoss()
        SNRdB = args['SNRdB_list'][i]
        model_location = "DeepJSCC_" + str(SNRdB) + "dB.pth"
        model = Autoencoder().to(device)
        model.load_state_dict(torch.load(model_location))
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data in testloader:
                inputs = data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                loss = criterion(inputs, outputs)
                test_loss += loss.item()
        test_cost = test_loss / len(testloader)
        psnr_list.append(round(10*math.log10(1/test_cost), 4))
    plt.plot(args["SNRdB_list"], psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
    plt.xlabel("SNR (dB)")
    plt.ylabel("PSNR (dB)")
    plt.grid(True)
    plt.legend()
    plt.show()
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'   # work around the "libiomp5 already initialized" OpenMP error some Windows/Anaconda setups hit
DeepJSCC_test(testloader)
The performance is not good. The dimensions need to be re-checked (see the shape check sketched below).
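As a starting point for that check, here is a minimal shape-check sketch, assuming the Encoder / Decoder classes and device defined above are in scope; the random batch is only a stand-in for CIFAR-10 data.

with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32).to(device)         # stand-in CIFAR-10 batch
    enc_net, dec_net = Encoder().to(device), Decoder().to(device)
    z = enc_net(dummy)
    print(z.size())                                       # torch.Size([2, 19, 5, 5]) -> K = 475 symbols per image
    raw = dec_net.decoder(z)                              # decoder output before interpolation
    print(raw.size())                                     # torch.Size([2, 3, 77, 77]), far larger than 32 x 32
    print(dec_net(z).size())                              # torch.Size([2, 3, 32, 32]) after the bilinear resize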