Proposed net
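An end-to-end DeepJSCC-style autoencoder for CIFAR-10: a CNN encoder compresses each image into a latent vector, the vector is power-normalized and passed through a simulated AWGN or Rayleigh fading channel, and a CNN decoder reconstructs the image. Training minimizes MSE and tracks PSNR, keeping only the best checkpoint per (latent dimension, SNR) pair.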
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import time
import os
from skimage.metrics import structural_similarity as ssim
from params import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
transf = tr.Compose([tr.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=True)
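# Note: shuffle=True on the test loader is unconventional but harmless here,
# since only the aggregate MSE / PSNR over the full set is reported.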
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 30, 30]
            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 28, 28]
            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 26, 26]
            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 24, 24]
            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 22, 22]
            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 20, 20]
            nn.Flatten(),
            nn.Linear(6400, self.latent_dim),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded
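With 32x32 inputs, each 5x5 convolution with padding 1 trims 2 pixels per side, so the six layers shrink the map from 32x32 to 20x20 and the flattened feature size is 16 * 20 * 20 = 6400. A quick shape check (illustrative sketch, not part of the training script):

enc = Encoder(latent_dim=512)
print(enc(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 512])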
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 6400),
            nn.PReLU(),
            nn.Unflatten(1, (16, 20, 20)),
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            # Output: [batch, 16, 20, 20]
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 22, 22]
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 24, 24]
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 26, 26]
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 28, 28]
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 30, 30]
            nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1),
            nn.Sigmoid(),
            # Output: [batch, 3, 32, 32]
        )

    def forward(self, x):
        decoded = self.decoder(x)
        return decoded
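The decoder mirrors the encoder: a linear layer restores the 6400 features, the transposed convolutions grow the map from 20x20 back to 32x32, and the final Sigmoid keeps pixels in [0, 1] to match the ToTensor() inputs. Shape check (illustrative):

dec = Decoder(latent_dim=512)
print(dec(torch.randn(2, 512)).shape)  # torch.Size([2, 3, 32, 32])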
class Autoencoder(nn.Module):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def Power_norm(self, z, P=1 / np.sqrt(2)):
        batch_size, z_dim = z.shape
        z_power = torch.sqrt(torch.sum(z ** 2, 1))
        z_M = z_power.repeat(z_dim, 1)
        return np.sqrt(P * z_dim) * z / z_M.t()
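    # Power_norm scales each row of z so that sum(z_out ** 2) == P * z_dim,
    # i.e. an average transmit power of P per real dimension.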
    def Power_norm_complex(self, z, P=1 / np.sqrt(2)):
        batch_size, z_dim = z.shape
        z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
        z_com_conj = torch.complex(z[:, 0:z_dim:2], -z[:, 1:z_dim:2])
        z_power = torch.sum(z_com * z_com_conj, 1).real
        z_M = z_power.repeat(z_dim // 2, 1)
        z_nlz = np.sqrt(P * z_dim) * z_com / torch.sqrt(z_M.t())
        z_out = torch.zeros(batch_size, z_dim).to(device)
        z_out[:, 0:z_dim:2] = z_nlz.real
        z_out[:, 1:z_dim:2] = z_nlz.imag
        return z_out
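    # Power_norm_complex does the same for complex signaling: even indices hold
    # the in-phase (real) parts, odd indices the quadrature (imaginary) parts;
    # after normalization the total power per row is again P * z_dim.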
    def AWGN_channel(self, x, snr, P=1):
        batch_size, length = x.shape
        gamma = 10 ** (snr / 10.0)
        noise = np.sqrt(P / gamma) * torch.randn(batch_size, length).to(device)
        y = x + noise
        return y
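    # AWGN: gamma = 10^(SNR_dB / 10) is the linear SNR, so for signal power P
    # the per-dimension noise variance is P / gamma.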
    def Fading_channel(self, x, snr, P=1):
        gamma = 10 ** (snr / 10.0)
        batch_size, feature_length = x.shape
        K = feature_length // 2
        h_I = torch.randn(batch_size, K).to(device)
        h_R = torch.randn(batch_size, K).to(device)
        h_com = torch.complex(h_I, h_R)
        x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
        y_com = h_com * x_com
        n_I = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
        n_R = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
        noise = torch.complex(n_I, n_R)
        y_add = y_com + noise
        y = y_add / h_com
        y_out = torch.zeros(batch_size, feature_length).to(device)
        y_out[:, 0:feature_length:2] = y.real
        y_out[:, 1:feature_length:2] = y.imag
        return y_out
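    # Rayleigh fading: each complex symbol is scaled by an i.i.d. complex
    # Gaussian gain h and corrupted by complex AWGN; dividing the received
    # symbol by h assumes perfect CSI at the receiver (zero-forcing equalization).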
    def forward(self, x, SNRdB, channel):
        encoded = self.encoder(x)
        if channel == 'AWGN':
            normalized_x = self.Power_norm(encoded)
            channel_output = self.AWGN_channel(normalized_x, SNRdB)
        elif channel == 'Rayleigh':
            normalized_complex_x = self.Power_norm_complex(encoded)
            channel_output = self.Fading_channel(normalized_complex_x, SNRdB)
        else:
            raise ValueError(f"Unknown channel: {channel}")
        decoded = self.decoder(channel_output)
        return decoded
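A minimal end-to-end smoke test (illustrative sketch assuming the definitions above; not part of the training script):

model = Autoencoder(latent_dim=256).to(device)
x = torch.rand(4, 3, 32, 32).to(device)
with torch.no_grad():
    x_hat = model(x, SNRdB=20, channel='Rayleigh')
print(x_hat.shape)  # torch.Size([4, 3, 32, 32])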
def model_train(trainloader, testloader, latent_dims):
    for latent_dim in latent_dims:
        for snr_i in range(len(params['SNR'])):
            model = Autoencoder(latent_dim=latent_dim).to(device)
            print("Model size : {}".format(count_parameters(model)))
            criterion = nn.MSELoss()
            optimizer = optim.Adam(model.parameters(), lr=params['LR'])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
            min_test_cost = float('inf')
            epochs_no_improve = 0
            n_epochs_stop = 43
            print("+++++ SNR = {} Training Start! +++++\t".format(params['SNR'][snr_i]))
            max_psnr = 0
            previous_best_model_path = None
            for epoch in range(params['EP']):
                # ========================================== Train ==========================================
                train_loss = 0.0
                model.train()
                timetemp = time.time()
                for data in trainloader:
                    inputs = data[0].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs, SNRdB=params['SNR'][snr_i], channel=params['channel'])
                    loss = criterion(inputs, outputs)
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item()
                train_cost = train_loss / len(trainloader)
                # PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1, since outputs lie in [0, 1]
                tr_psnr = round(10 * math.log10(1.0 / train_cost), 3)
                # ========================================== Test ==========================================
                test_loss = 0.0
                model.eval()
                with torch.no_grad():
                    for data in testloader:
                        inputs = data[0].to(device)
                        outputs = model(inputs, SNRdB=params['SNR'][snr_i], channel=params['channel'])
                        loss = criterion(inputs, outputs)
                        test_loss += loss.item()
                test_cost = test_loss / len(testloader)
                test_psnr = round(10 * math.log10(1.0 / test_cost), 3)
                scheduler.step(test_cost)
                if test_cost < min_test_cost:
                    min_test_cost = test_cost
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                if epochs_no_improve == n_epochs_stop:
                    print("Early stopping!")
                    break  # early stop
                training_time = time.time() - timetemp
                print("[{:>3}-Epoch({:>5}sec.)] PSNR(Train / Val) : {:>6.4f} / {:>6.4f} ".format(
                    epoch + 1, round(training_time, 2), tr_psnr, test_psnr))
                if test_psnr > max_psnr:
                    save_folder = 'trained_model'
                    if not os.path.exists(save_folder):
                        os.makedirs(save_folder)
                    previous_psnr = max_psnr
                    max_psnr = test_psnr
                    # Remove the previously saved best checkpoint, if any
                    if previous_best_model_path is not None:
                        os.remove(previous_best_model_path)
                    print(f"Performance update!! {previous_psnr} to {max_psnr}")
                    save_path = os.path.join(save_folder,
                                             f"DeepJSCC(DIM={latent_dim}_SNR={params['SNR'][snr_i]}_PSNR={max_psnr}).pt")
                    torch.save(model, save_path)
                    print(f"Saved new best model at {save_path}")
                    previous_best_model_path = save_path


model_train(trainloader, testloader, params['DIM'])
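The hyperparameters live in a separate params.py, pulled in at the top of the script via from params import *: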
params = {
    'BS': 64,
    'LR': 0.0005,
    'EP': 1000,
    'SNR': [40, 20],
    'DIM': [1024, 512, 256],
    'channel': 'Rayleigh',
}