DeepJSCC(My code)_batch:32, lr:0.001, R:1/6
Happy PinGu 2023. 10. 25. 10:14

import cv2
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
args = {
'BATCH_SIZE' : 32,
'LEARNING_RATE' : 0.001,
'NUM_EPOCH' : 500,
'SNRdB_list' : [0, 15, 30],
'latent_dim' : 512,
'input_dim' : 32*32
}
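# Quick sanity check (sketch): the "R:1/6" in the title is the bandwidth compression
# ratio, i.e. latent_dim channel symbols per 3 x 32 x 32 source symbols.
print(Fraction(args['latent_dim'], 3 * args['input_dim']))   # expected: 1/6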
transf = tr.Compose([tr.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
c_hid = 32
self.encoder = nn.Sequential(
nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
nn.ReLU(),
nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
nn.ReLU(),
nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
nn.ReLU(),
nn.Flatten(), # Image grid to single feature vector
nn.Linear(2 * 16 * c_hid, args['latent_dim'])
)
def forward(self, x):
encoded = self.encoder(x)
return encoded
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
c_hid = 32
self.linear = nn.Sequential(
nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
nn.ReLU()
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
nn.ReLU(),
nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
nn.ReLU(),
nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
)
def forward(self, x):
x = self.linear(x)
x = x.reshape(x.shape[0], -1, 4, 4)
decoded = self.decoder(x)
return decoded
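# Shape check (sketch): a CIFAR-10 batch (B, 3, 32, 32) should map to a latent of size
# (B, latent_dim) and back to (B, 3, 32, 32). The _-prefixed names exist only for this check.
_x = torch.randn(2, 3, 32, 32, device=device)
_z = Encoder().to(device)(_x)
_x_hat = Decoder().to(device)(_z)
print(_z.shape, _x_hat.shape)   # expected: torch.Size([2, 512]) torch.Size([2, 3, 32, 32])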
class Autoencoder(nn.Module):
def __init__(
self,
encoder_class: object = Encoder,
decoder_class: object = Decoder
):
super(Autoencoder, self).__init__()
self.encoder = encoder_class()
self.decoder = decoder_class()
    def AWGN(self, input, SNRdB):
        # Power-normalize the latent (unit L2 norm -> average power 1/K per symbol), then
        # add Gaussian noise with variance 1/(K*SNR) so the per-symbol SNR equals SNRdB.
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
        return normalized_tensor + n
def forward(self, x, SNRdB):
encoded = self.encoder(x)
channel_output = self.AWGN(encoded, SNRdB)
decoded = self.decoder(channel_output)
return decoded
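# Channel check (sketch): after normalization the latent has unit L2 norm, so the average
# signal power per symbol is 1/K; with noise std 1/sqrt(K*SNR) the measured SNR of the
# AWGN output should be close to the requested SNRdB (here 10 dB, up to estimation noise).
_model = Autoencoder().to(device)
_z = f.normalize(torch.randn(4, args['latent_dim'], device=device), dim=1)
_n = _model.AWGN(_z, SNRdB=10) - _z            # _z is already unit-norm, so this isolates the noise
print(10 * math.log10((_z.pow(2).mean() / _n.pow(2).mean()).item()))   # ~10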
def model_train(trainloader, testloader) :
for i in range( len(args['SNRdB_list']) ):
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])
print("+++++ SNR = {} Training Start! +++++\t".format(args['SNRdB_list'][i]))
for epoch in range(args['NUM_EPOCH']) :
#========================================== Train ==========================================
train_loss = 0.0
model.train()
for data in trainloader :
inputs = data[0].to(device)
optimizer.zero_grad()
outputs = model( inputs, SNRdB = args['SNRdB_list'][i])
loss = criterion(inputs, outputs)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_cost = train_loss / len(trainloader)
#========================================== Test ==========================================
            test_loss = 0.0
            model.eval()
            with torch.no_grad():                       # no gradients needed during evaluation
                for data in testloader :
                    inputs = data[0].to(device)
                    outputs = model( inputs, SNRdB = args['SNRdB_list'][i])
                    loss = criterion(inputs, outputs)
                    test_loss += loss.item()
            test_cost = test_loss / len(testloader)
# print("[{} Epoch] Train Loss : {} Val. Loss : {}".format(epoch + 1, round(train_cost, 6), round(test_cost, 6)))
print("[{} Epoch] Train PSNR : {} Val. PSNR : {}".format(epoch + 1, round(10 * math.log10(1.0 / train_cost), 4), round(10 * math.log10(1.0 / test_cost), 4)))
#print("[{} Epoch] Train PSNR : {} Val. PSNR : {}".format(cv2.PSNR(inputs, outputs)))
torch.save(model.state_dict(), "DeepJSCC(" + str(args['latent_dim'])+ ")_" + str(args['SNRdB_list'][i]) + "dB.pth")
model_train(trainloader, testloader)
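# Note on the PSNR values logged during training (illustrative sketch): ToTensor() scales
# pixels to [0, 1], so the peak value is 1 and PSNR = 10*log10(MAX^2 / MSE) reduces to
# the 10*log10(1 / MSE) used above. psnr_from_mse below is just a helper for this note.
def psnr_from_mse(mse, max_val=1.0):
    return 10 * math.log10((max_val ** 2) / mse)

print(psnr_from_mse(0.001))   # an epoch-average MSE of 0.001 corresponds to 30 dB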
def DeepJSCC_test(testloader) :
psnr_list = []
for i in range(len(args['SNRdB_list'])):
criterion = nn.MSELoss()
SNRdB = args['SNRdB_list'][i]
model_location = "DeepJSCC(" + str(args['latent_dim'])+ ")_" + str(SNRdB) + "dB.pth"
model = Autoencoder().to(device)
        model.load_state_dict(torch.load(model_location, map_location=device))
        test_loss = 0.0
        model.eval()                                    # switch to eval mode once, outside the data loop
        with torch.no_grad():                           # no gradients needed during evaluation
            for data in testloader :
                inputs = data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                loss = criterion(inputs, outputs)
                test_loss += loss.item()
        test_cost = test_loss / len(testloader)
psnr_list.append(round(10*math.log10(1/test_cost), 4))
plt.plot(args["SNRdB_list"], psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.grid(True)
plt.legend()
plt.title('batch size:' + str(args['BATCH_SIZE']) + ' || lr:' + str(args['LEARNING_RATE']) + ' || R:' + str( Fraction(args['latent_dim'], int((args['input_dim'] * 3))) ))
plt.xlabel('SNR(dB)')
plt.ylabel('PSNR')
plt.ylim([15, 40])
plt.show()
return psnr_list
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'   # workaround for the duplicate OpenMP runtime error some MKL setups raise when plotting
psnr_list= DeepJSCC_test(testloader)
print("PSNR data : {}".format(psnr_list))