import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import math
import time
from params import *  # hyperparameters (see params.py below)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def count_parameters(model):
    """Number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def psnr(img1, img2):
    """PSNR in dB between two image tensors with values in [0, 1]."""
    mse = nn.functional.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    pixel_max = 1.0
    return 20 * math.log10(pixel_max / math.sqrt(mse.item()))
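# Quick sanity check of the helper (illustrative tensors, not part of training):
# identical images give inf, and Gaussian noise of std 0.1 lands near 20 dB.
_img = torch.rand(4, 3, 32, 32)
_noisy = (_img + 0.1 * torch.randn_like(_img)).clamp(0.0, 1.0)
print(psnr(_img, _img))    # inf
print(psnr(_img, _noisy))  # roughly 20 dB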
# Define CIFAR-10 dataset with transformations
transform = transforms.Compose([
    transforms.ToTensor()  # scales pixel values to [0, 1]
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=False)  # no need to shuffle for evaluation
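# Each batch is a tensor of shape (BS, 3, 32, 32); a quick peek (assumes BS = 64 from params.py):
_images, _labels = next(iter(trainloader))
print(_images.shape)  # torch.Size([64, 3, 32, 32]), values in [0, 1]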
class PatchLSTMEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layer, sequence_length):
        super(PatchLSTMEncoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.lstm1 = nn.LSTM(input_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, hidden_dim, num_layer, batch_first=True)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        # Cut the image into non-overlapping patches, then flatten each patch
        # into a vector: (batch_size, sequence_length, patch_dim)
        x = x.unfold(2, params['patch_size'], params['patch_size']).unfold(3, params['patch_size'], params['patch_size'])
        x = x.contiguous().view(batch_size, self.sequence_length, -1)
        lstm_out, _ = self.lstm1(x)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm2(lstm_out)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)
        return lstm_out
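# Shape check for the encoder (assumes patch_size = 1 and hidden_dim = 1 from params.py,
# so each 32x32 RGB image becomes 1024 length-3 patch vectors):
_enc = PatchLSTMEncoder(input_dim=3, hidden_dim=1, num_layer=1, sequence_length=1024)
print(_enc(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 1024, 1])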
class PatchLSTMDecoder(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layer, sequence_length):
        super(PatchLSTMDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.lstm1 = nn.LSTM(hidden_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, output_dim, num_layer, batch_first=True)
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        lstm_out, _ = self.lstm1(x)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm2(lstm_out)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        #lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)
        #lstm_out = self.sigmoid(lstm_out)
        batch_size = x.size(0)
        # Note: this reshape only recovers a 3 x 32 x 32 image when patch_size = 1,
        # where sequence_length = 32 * 32 and output_dim = 3.
        img_dim = int(self.sequence_length ** 0.5)
        output = lstm_out.contiguous().view(batch_size, 3, img_dim, img_dim)
        return output
# Model, Loss, Optimizer
class Model(nn.Module):
    def __init__(self, input_dim, sequence_length):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.sequence_length = sequence_length
        self.encoder = PatchLSTMEncoder(input_dim, params['hidden_dim'], params['num_layer'], sequence_length)
        self.decoder = PatchLSTMDecoder(params['hidden_dim'], input_dim, params['num_layer'], sequence_length)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
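# End-to-end shape check: the autoencoder maps an image batch back to its own shape
# (again assuming patch_size = 1, so sequence_length = 1024 and input_dim = 3):
_ae = Model(input_dim=3, sequence_length=1024)
print(_ae(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 3, 32, 32])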
sequence_length = (32 // params['patch_size']) ** 2              # patches per image; 1024 when patch_size = 1
input_dim = 3 * params['patch_size'] * params['patch_size']      # values per patch; 3 when patch_size = 1
model = Model(input_dim, sequence_length).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=params['LR'])
num_epochs = params['EP']
print("Model size : {}".format(count_parameters(model)))
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    timetemp = time.time()  # per-epoch timer (covers both training and evaluation)
    for images, _ in trainloader:
        images = images.to(device)
        optimizer.zero_grad()
        decoded = model(images)
        loss = criterion(decoded, images)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_avg_loss = train_loss / len(trainloader)
    # PSNR from the average MSE: 10 * log10(1 / MSE) for pixel values in [0, 1]
    train_avg_psnr = round(10 * math.log10(1.0 / train_avg_loss), 3)
    #=========================================
    test_loss = 0
    model.eval()
    with torch.no_grad():
        for images, _ in testloader:
            images = images.to(device)
            decoded = model(images)
            loss = criterion(decoded, images)
            test_loss += loss.item()
    test_avg_loss = test_loss / len(testloader)
    test_avg_psnr = round(10 * math.log10(1.0 / test_avg_loss), 3)
    training_time = time.time() - timetemp
    print(
        "[{:>3}-Epoch({:>5}sec.)] PSNR(Train / Val) : {:>6.4f} / {:>6.4f}".format(
            epoch + 1, round(training_time, 2), train_avg_psnr, test_avg_psnr))
print("Training completed.")
# params.py -- hyperparameters imported above via `from params import *`
params = {
    'BS' : 64,
    'LR' : 0.008,
    'EP' : 3000,
    'patch_size' : 1,
    'hidden_dim' : 1,
    'num_layer' : 1
}
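With patch_size = 1 and hidden_dim = 1, the encoder compresses each pixel's 3 channel values into a single latent value per sequence step, i.e. a 3:1 reduction (1,024 latent floats for a 3,072-value image). Training log (excerpt):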
[ 12-Epoch(21.32sec.)] PSNR(Train / Val) : 26.8620 / 26.9130
[ 13-Epoch(22.06sec.)] PSNR(Train / Val) : 26.8960 / 26.9490
[ 14-Epoch(21.43sec.)] PSNR(Train / Val) : 26.9480 / 27.0240
[ 15-Epoch(21.44sec.)] PSNR(Train / Val) : 26.9880 / 27.0520
[ 16-Epoch(21.43sec.)] PSNR(Train / Val) : 27.0260 / 27.0560
[ 17-Epoch(21.64sec.)] PSNR(Train / Val) : 27.0620 / 27.0140
[ 18-Epoch(21.57sec.)] PSNR(Train / Val) : 27.0990 / 27.1650
[ 19-Epoch(21.62sec.)] PSNR(Train / Val) : 27.1120 / 27.1260
[ 20-Epoch(21.27sec.)] PSNR(Train / Val) : 27.1720 / 27.2180
[ 21-Epoch(21.12sec.)] PSNR(Train / Val) : 27.1860 / 27.2270
[ 22-Epoch(21.04sec.)] PSNR(Train / Val) : 27.1550 / 27.2280
[ 23-Epoch(21.16sec.)] PSNR(Train / Val) : 27.2400 / 27.2420
[ 24-Epoch(21.85sec.)] PSNR(Train / Val) : 27.2370 / 27.1760
[ 25-Epoch(21.39sec.)] PSNR(Train / Val) : 27.2660 / 27.3200
[ 26-Epoch(21.12sec.)] PSNR(Train / Val) : 27.2370 / 27.2950
[ 27-Epoch(21.25sec.)] PSNR(Train / Val) : 27.2820 / 27.3390
[ 28-Epoch(21.55sec.)] PSNR(Train / Val) : 27.2790 / 27.3410