UOMOP
DeepJSCC performance (DIM = 768, 1536, 2304)
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import time
import os
from params import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
transf = tr.Compose([tr.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=True)
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2),   # Output: 16x16
            nn.PReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),  # Output: 8x8
            nn.PReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),  # Output: 4x4
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2),  # Output: 4x4
            nn.PReLU(),
            nn.Flatten(),
            nn.Linear(4 * 4 * 64, self.latent_dim),
        )

    def forward(self, x):
        output = self.encoder(x)
        return output
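# For reference, each CIFAR-10 image carries 3 x 32 x 32 = 3072 source values, so the
# DIM settings in the title correspond to different bandwidth ratios. A minimal sketch
# of that calculation (my own addition, using the Fraction import above; it counts each
# latent entry as one real channel use, so halve the ratios if pairs of entries are
# mapped to complex channel symbols):
source_dim = 3 * 32 * 32                   # 3072 values per CIFAR-10 image
for dim in (768, 1536, 2304):
    print(dim, Fraction(dim, source_dim))  # 768 -> 1/4, 1536 -> 1/2, 2304 -> 3/4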
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 4 * 4 * 64),
            nn.PReLU(),
            nn.Unflatten(1, (64, 4, 4)),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),                    # Output: 4x4
            nn.PReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 8x8
            nn.PReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding=2),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=2, output_padding=1),  # Output: 16x16
            nn.PReLU(),
            nn.Conv2d(16, 16, kernel_size=5, padding=2),
            nn.PReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=5, stride=2, padding=2, output_padding=1),   # Output: 32x32
            nn.Sigmoid()  # Sigmoid activation at the output
        )

    def forward(self, x):
        x = self.decoder(x)
        return x
class Autoencoder(nn.Module):
    def __init__(
        self,
        latent_dim,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = encoder_class(latent_dim=latent_dim)
        self.decoder = decoder_class(latent_dim=latent_dim)

    def AWGN(self, input, SNRdB):
        # Power-normalize the latent vector, then add Gaussian noise scaled to the target SNR.
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        std = 1 / math.sqrt(self.latent_dim * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
        return normalized_tensor + n

    def forward(self, x, SNRdB):
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)
        return decoded
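# Quick sanity check of the AWGN scaling (my own standalone addition, not part of the
# training flow): after f.normalize each latent vector has unit power, i.e. 1/latent_dim
# per dimension, so a noise variance of 1/(latent_dim * SNR) should reproduce the
# requested SNR. With latent_dim = 768 and SNRdB = 10 the result lands near 10 dB.
def check_awgn_snr(latent_dim=768, SNRdB=10, batch=1000):
    SNR = 10.0 ** (SNRdB / 10.0)
    x = f.normalize(torch.randn(batch, latent_dim), dim=1)               # unit-norm latents
    n = torch.normal(0, 1 / math.sqrt(latent_dim * SNR), size=x.size())  # channel noise
    return 10 * torch.log10(x.pow(2).mean() / n.pow(2).mean())           # empirical SNR in dB
# print(check_awgn_snr())  # ~ 10 dB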
def model_train(trainloader, testloader, latent_dims):
    for latent_dim in latent_dims:
        for snr_i in range(len(params['SNR'])):
            model = Autoencoder(latent_dim=latent_dim, encoder_class=Encoder, decoder_class=Decoder).to(device)
            print("Model size : {}".format(count_parameters(model)))
            criterion = nn.MSELoss()
            optimizer = optim.Adam(model.parameters(), lr=params['LR'])
            min_test_cost = float('inf')
            epochs_no_improve = 0
            n_epochs_stop = 20
            print("+++++ SNR = {} Training Start! +++++\t".format(params['SNR'][snr_i]))
            max_psnr = 0
            previous_best_model_path = None

            for epoch in range(params['EP']):
                # ========================================== Train ==========================================
                train_loss = 0.0
                model.train()
                timetemp = time.time()
                for data in trainloader:
                    inputs = data[0].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs, SNRdB=params['SNR'][snr_i])
                    loss = criterion(inputs, outputs)
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item()
                train_cost = train_loss / len(trainloader)
                tr_psnr = round(10 * math.log10(1.0 / train_cost), 3)

                # ========================================== Test ==========================================
                test_loss = 0.0
                model.eval()
                with torch.no_grad():
                    for data in testloader:
                        inputs = data[0].to(device)
                        outputs = model(inputs, SNRdB=params['SNR'][snr_i])
                        loss = criterion(inputs, outputs)
                        test_loss += loss.item()
                test_cost = test_loss / len(testloader)
                val_psnr = round(10 * math.log10(1.0 / test_cost), 3)

                if test_cost < min_test_cost:
                    min_test_cost = test_cost
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                if epochs_no_improve == n_epochs_stop:
                    print("Early stopping!")
                    break  # stop training this model early

                training_time = time.time() - timetemp
                print("[{}-Epoch({}sec.)] Train PSNR : {:.4f}\tVal PSNR : {:.4f}".format(
                    epoch + 1, round(training_time, 2), tr_psnr, val_psnr))

                if val_psnr > max_psnr:
                    save_folder = 'trained_model'
                    if not os.path.exists(save_folder):
                        os.makedirs(save_folder)
                    previous_psnr = max_psnr
                    max_psnr = val_psnr
                    # Delete the previous best checkpoint, if one exists.
                    if previous_best_model_path is not None:
                        os.remove(previous_best_model_path)
                    print(f"Performance update!! {previous_psnr} to {max_psnr}")
                    save_path = os.path.join(save_folder,
                                             f"DeepJSCC(DIM={latent_dim}_SNR={params['SNR'][snr_i]}_PSNR={max_psnr}).pt")
                    torch.save(model, save_path)
                    print(f"Saved new best model at {save_path}")
                    previous_best_model_path = save_path
model_train(trainloader, testloader, params['DIM'])
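Each best checkpoint is saved as a full model object with torch.save(model, ...), so it can be reloaded later for evaluation. A rough sketch of reloading one (the filename below is a placeholder for whatever a run actually produced, 10 dB is just an example SNR, and on newer PyTorch versions torch.load may also need weights_only=False for full-model checkpoints):

model = torch.load("trained_model/DeepJSCC(DIM=768_SNR=10_PSNR=XX.XX).pt", map_location=device)
model.eval()
test_loss = 0.0
with torch.no_grad():
    for data in testloader:
        inputs = data[0].to(device)
        outputs = model(inputs, SNRdB=10)  # evaluate at the SNR the model was trained for
        test_loss += nn.MSELoss()(inputs, outputs).item()
print("Test PSNR:", round(10 * math.log10(1.0 / (test_loss / len(testloader))), 3))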