UOMOP
Image transmission using LSTM
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import math
from params import *
import torchvision
import time
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def count_parameters(model):
    # Number of trainable parameters in the model.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def psnr(img1, img2):
    # PSNR between two image tensors scaled to [0, 1].
    mse = nn.functional.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    pixel_max = 1.0
    return 20 * math.log10(pixel_max / math.sqrt(mse.item()))
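# Note: psnr() expects image tensors scaled to [0, 1], which ToTensor below provides.
# The training loop further down reports PSNR computed from the epoch-averaged MSE,
# i.e. 10 * log10(1 / avg_mse), rather than calling this helper per batch.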
# Define CIFAR-10 dataset with transformations
transform = transforms.Compose([
    transforms.ToTensor()
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=True)
class PatchLSTMEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layer, sequence_length):
        super(PatchLSTMEncoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length

        self.lstm1 = nn.LSTM(input_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, hidden_dim, num_layer, batch_first=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        batch_size, _, _, _ = x.size()

        # Split the image into patches and flatten each patch into one sequence step.
        x = x.unfold(2, params['patch_size'], params['patch_size']).unfold(3, params['patch_size'], params['patch_size'])
        x = x.contiguous().view(batch_size, self.sequence_length, -1)  # (batch_size, sequence_length, patch_dim)

        lstm_out, _ = self.lstm1(x)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm2(lstm_out)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)  # (batch_size, sequence_length, hidden_dim)
        return lstm_out
class PatchLSTMDecoder(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_layer, sequence_length):
        super(PatchLSTMDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length

        self.lstm1 = nn.LSTM(hidden_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, output_dim, num_layer, batch_first=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        lstm_out, _ = self.lstm1(x)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm2(lstm_out)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        lstm_out = self.relu(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)  # (batch_size, sequence_length, output_dim)
        # lstm_out = self.sigmoid(lstm_out)

        # Reassemble the decoded sequence into an image tensor.
        batch_size = x.size(0)
        img_dim = int(self.sequence_length ** 0.5)
        output = lstm_out.contiguous().view(batch_size, 3, img_dim, img_dim)
        return output
# Model, Loss, Optimizer
class Model(nn.Module):
    def __init__(self, input_dim, sequence_length):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.sequence_length = sequence_length
        self.encoder = PatchLSTMEncoder(input_dim, params['hidden_dim'], params['num_layer'], sequence_length)
        self.decoder = PatchLSTMDecoder(params['hidden_dim'], input_dim, params['num_layer'], sequence_length)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
sequence_length = (32 // params['patch_size']) ** 2
input_dim = 3 * params['patch_size'] * params['patch_size']
model = Model(input_dim, sequence_length).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=params['LR'])
num_epochs = params['EP']
print("Model size : {}".format(count_parameters(model)))
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    timetemp = time.time()

    for images, _ in trainloader:
        images = images.to(device)

        optimizer.zero_grad()
        decoded = model(images)
        loss = criterion(decoded, images)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_avg_loss = train_loss / len(trainloader)
    train_avg_psnr = round(10 * math.log10(1.0 / train_avg_loss), 3)

    #=========================================

    test_loss = 0
    model.eval()
    with torch.no_grad():
        for images, _ in testloader:
            images = images.to(device)
            decoded = model(images)
            loss = criterion(decoded, images)
            test_loss += loss.item()

    test_avg_loss = test_loss / len(testloader)
    test_avg_psnr = round(10 * math.log10(1.0 / test_avg_loss), 3)

    training_time = time.time() - timetemp

    print(
        "[{:>3}-Epoch({:>5}sec.)] PSNR(Train / Val) : {:>6.4f} / {:>6.4f}".format(
            epoch + 1, round(training_time, 2), train_avg_psnr, test_avg_psnr))

print("Training completed.")
# params.py: hyperparameters imported at the top of the script
params = {
    'BS'         : 64,
    'LR'         : 0.0005,
    'EP'         : 3000,
    'patch_size' : 1,
    'hidden_dim' : 1,
    'num_layer'  : 1
}
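With these settings, the derived quantities in the script work out to sequence_length = (32 // 1) ** 2 = 1024 patches per CIFAR-10 image and input_dim = 3 * 1 * 1 = 3 values per patch, so with hidden_dim = 1 the encoder compresses each 3-value step to a single latent value. A quick shape check, as a sketch that assumes the model built above:

# Sketch: verify encoder/decoder shapes with a dummy CIFAR-10-sized batch.
dummy = torch.rand(2, 3, 32, 32).to(device)   # (batch, channels, H, W)
encoded = model.encoder(dummy)                # expected shape: (2, 1024, 1)
decoded = model.decoder(encoded)              # expected shape: (2, 3, 32, 32)
print(encoded.shape, decoded.shape)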
Training log:
C:\Users\dowon\anaconda3\envs\dowon_simul\python.exe C:\Users\dowon\main_research\LSTM_cifar10\main.py
Files already downloaded and verified
Files already downloaded and verified
Model size : 85064
[ 1-Epoch(22.02sec.)] PSNR(Train / Val) : 13.9430 / 16.9690
[ 2-Epoch(21.84sec.)] PSNR(Train / Val) : 17.1040 / 17.2060
[ 3-Epoch( 21.9sec.)] PSNR(Train / Val) : 17.3630 / 17.8890
[ 4-Epoch( 21.6sec.)] PSNR(Train / Val) : 18.3060 / 18.4280
[ 5-Epoch( 21.8sec.)] PSNR(Train / Val) : 18.5470 / 18.6290
[ 6-Epoch(21.76sec.)] PSNR(Train / Val) : 18.6550 / 18.7290
[ 7-Epoch(21.62sec.)] PSNR(Train / Val) : 18.8380 / 19.1920
[ 8-Epoch( 21.6sec.)] PSNR(Train / Val) : 19.4620 / 19.6570
[ 9-Epoch(22.48sec.)] PSNR(Train / Val) : 20.8220 / 23.5100
[ 10-Epoch(22.17sec.)] PSNR(Train / Val) : 23.7390 / 23.9160
[ 11-Epoch(21.57sec.)] PSNR(Train / Val) : 24.0340 / 24.0520
[ 12-Epoch(21.66sec.)] PSNR(Train / Val) : 24.2630 / 24.3960
[ 13-Epoch(21.66sec.)] PSNR(Train / Val) : 24.4090 / 24.2440
[ 14-Epoch(21.66sec.)] PSNR(Train / Val) : 24.4860 / 24.5430
[ 15-Epoch(21.98sec.)] PSNR(Train / Val) : 24.5340 / 24.1220
[ 16-Epoch(21.91sec.)] PSNR(Train / Val) : 24.5780 / 24.5370
[ 17-Epoch(21.82sec.)] PSNR(Train / Val) : 24.6180 / 24.6780
[ 18-Epoch(21.67sec.)] PSNR(Train / Val) : 24.6580 / 24.6570
[ 19-Epoch(25.06sec.)] PSNR(Train / Val) : 24.6850 / 24.7450
[ 20-Epoch( 22.4sec.)] PSNR(Train / Val) : 24.7190 / 24.7740
[ 21-Epoch( 21.7sec.)] PSNR(Train / Val) : 24.7520 / 24.7650
[ 22-Epoch(22.09sec.)] PSNR(Train / Val) : 24.7740 / 24.6580
[ 23-Epoch(21.92sec.)] PSNR(Train / Val) : 24.8120 / 24.8110
[ 24-Epoch(22.16sec.)] PSNR(Train / Val) : 24.8450 / 24.9040
[ 25-Epoch(22.23sec.)] PSNR(Train / Val) : 24.8850 / 24.9240
[ 26-Epoch( 22.6sec.)] PSNR(Train / Val) : 24.9300 / 24.9720
[ 27-Epoch(21.83sec.)] PSNR(Train / Val) : 24.9610 / 25.0130
[ 28-Epoch(22.08sec.)] PSNR(Train / Val) : 25.0060 / 25.0310
[ 29-Epoch(21.88sec.)] PSNR(Train / Val) : 25.0700 / 25.1470
[ 30-Epoch(21.83sec.)] PSNR(Train / Val) : 25.1480 / 25.2230
[ 31-Epoch(22.27sec.)] PSNR(Train / Val) : 25.2460 / 25.2810
[ 32-Epoch(21.98sec.)] PSNR(Train / Val) : 25.3520 / 25.4200
[ 33-Epoch(22.23sec.)] PSNR(Train / Val) : 25.4650 / 25.3800
[ 34-Epoch(22.11sec.)] PSNR(Train / Val) : 25.5500 / 25.6030
[ 35-Epoch(22.05sec.)] PSNR(Train / Val) : 25.6130 / 25.6650
[ 36-Epoch(22.31sec.)] PSNR(Train / Val) : 25.6710 / 25.7140
[ 37-Epoch(22.37sec.)] PSNR(Train / Val) : 25.7260 / 25.7790
[ 38-Epoch(22.02sec.)] PSNR(Train / Val) : 25.7790 / 25.6810
[ 39-Epoch(21.68sec.)] PSNR(Train / Val) : 25.8310 / 25.8970
[ 40-Epoch(21.52sec.)] PSNR(Train / Val) : 25.8820 / 25.9380
[ 41-Epoch(21.83sec.)] PSNR(Train / Val) : 25.9390 / 26.0050
[ 42-Epoch(22.55sec.)] PSNR(Train / Val) : 26.0220 / 26.0980
[ 43-Epoch(21.82sec.)] PSNR(Train / Val) : 26.1020 / 26.1300
[ 44-Epoch(21.59sec.)] PSNR(Train / Val) : 26.1550 / 26.2100
[ 45-Epoch(21.69sec.)] PSNR(Train / Val) : 26.1960 / 26.2500
[ 46-Epoch(21.84sec.)] PSNR(Train / Val) : 26.2290 / 26.2780
[ 47-Epoch(21.71sec.)] PSNR(Train / Val) : 26.2590 / 26.3020
[ 48-Epoch(22.25sec.)] PSNR(Train / Val) : 26.2830 / 26.3400
[ 49-Epoch(21.81sec.)] PSNR(Train / Val) : 26.3100 / 26.3530
[ 50-Epoch(21.79sec.)] PSNR(Train / Val) : 26.3300 / 26.3270
[ 51-Epoch(21.77sec.)] PSNR(Train / Val) : 26.3550 / 26.3870
[ 52-Epoch( 22.0sec.)] PSNR(Train / Val) : 26.3950 / 26.3850
[ 53-Epoch(22.57sec.)] PSNR(Train / Val) : 26.4720 / 26.5260
[ 54-Epoch(21.96sec.)] PSNR(Train / Val) : 26.5150 / 26.5530
[ 55-Epoch(22.04sec.)] PSNR(Train / Val) : 26.5400 / 26.5750
[ 56-Epoch(21.97sec.)] PSNR(Train / Val) : 26.5570 / 26.5810
[ 57-Epoch(21.73sec.)] PSNR(Train / Val) : 26.5720 / 26.6120
[ 58-Epoch(21.88sec.)] PSNR(Train / Val) : 26.5850 / 26.6200
[ 59-Epoch(21.89sec.)] PSNR(Train / Val) : 26.5940 / 26.4810
[ 60-Epoch(21.55sec.)] PSNR(Train / Val) : 26.6040 / 26.6110
[ 61-Epoch( 21.7sec.)] PSNR(Train / Val) : 26.6140 / 26.6590
[ 62-Epoch(21.65sec.)] PSNR(Train / Val) : 26.6240 / 26.6400
[ 63-Epoch(21.75sec.)] PSNR(Train / Val) : 26.6330 / 26.6650
[ 64-Epoch(21.98sec.)] PSNR(Train / Val) : 26.6430 / 26.6710
[ 65-Epoch(21.76sec.)] PSNR(Train / Val) : 26.6460 / 26.6830
[ 66-Epoch(21.72sec.)] PSNR(Train / Val) : 26.6570 / 26.6450
[ 67-Epoch(21.79sec.)] PSNR(Train / Val) : 26.6620 / 26.6790
[ 68-Epoch(21.79sec.)] PSNR(Train / Val) : 26.6690 / 26.6350
[ 69-Epoch(22.54sec.)] PSNR(Train / Val) : 26.6750 / 26.7070
[ 70-Epoch(22.28sec.)] PSNR(Train / Val) : 26.6810 / 26.6970
[ 71-Epoch(21.72sec.)] PSNR(Train / Val) : 26.6900 / 26.7200
[ 72-Epoch(21.96sec.)] PSNR(Train / Val) : 26.6950 / 26.7210
[ 73-Epoch(21.84sec.)] PSNR(Train / Val) : 26.7010 / 26.7210
[ 74-Epoch(21.93sec.)] PSNR(Train / Val) : 26.7060 / 26.7130
[ 75-Epoch(22.21sec.)] PSNR(Train / Val) : 26.7130 / 26.7320
[ 76-Epoch(21.72sec.)] PSNR(Train / Val) : 26.7200 / 26.7270
[ 77-Epoch(21.88sec.)] PSNR(Train / Val) : 26.7260 / 26.7490
[ 78-Epoch(21.69sec.)] PSNR(Train / Val) : 26.7310 / 26.7060
[ 79-Epoch(21.82sec.)] PSNR(Train / Val) : 26.7360 / 26.7580
[ 80-Epoch(22.16sec.)] PSNR(Train / Val) : 26.7440 / 26.7680
[ 81-Epoch(22.18sec.)] PSNR(Train / Val) : 26.7470 / 26.7390
[ 82-Epoch(21.75sec.)] PSNR(Train / Val) : 26.7530 / 26.7750
[ 83-Epoch( 21.7sec.)] PSNR(Train / Val) : 26.7590 / 26.7790
[ 84-Epoch(21.82sec.)] PSNR(Train / Val) : 26.7640 / 26.7840
[ 85-Epoch( 22.3sec.)] PSNR(Train / Val) : 26.7690 / 26.7510
[ 86-Epoch(22.54sec.)] PSNR(Train / Val) : 26.7740 / 26.7860
[ 87-Epoch(22.37sec.)] PSNR(Train / Val) : 26.7800 / 26.8020
[ 88-Epoch(21.72sec.)] PSNR(Train / Val) : 26.7870 / 26.8110
[ 89-Epoch( 22.3sec.)] PSNR(Train / Val) : 26.7880 / 26.8130
[ 90-Epoch(22.13sec.)] PSNR(Train / Val) : 26.7950 / 26.8070
[ 91-Epoch(24.31sec.)] PSNR(Train / Val) : 26.8000 / 26.8280
[ 92-Epoch(22.32sec.)] PSNR(Train / Val) : 26.8070 / 26.8170
[ 93-Epoch(22.07sec.)] PSNR(Train / Val) : 26.8120 / 26.8230
[ 94-Epoch(21.75sec.)] PSNR(Train / Val) : 26.8170 / 26.8180
[ 95-Epoch( 21.7sec.)] PSNR(Train / Val) : 26.8230 / 26.8380
[ 96-Epoch(22.01sec.)] PSNR(Train / Val) : 26.8310 / 26.8450
[ 97-Epoch(22.39sec.)] PSNR(Train / Val) : 26.8350 / 26.8510
[ 98-Epoch(21.99sec.)] PSNR(Train / Val) : 26.8390 / 26.8680
[ 99-Epoch(21.94sec.)] PSNR(Train / Val) : 26.8470 / 26.8690
[100-Epoch(21.81sec.)] PSNR(Train / Val) : 26.8530 / 26.8580
[101-Epoch(22.35sec.)] PSNR(Train / Val) : 26.8580 / 26.8880
[102-Epoch(22.58sec.)] PSNR(Train / Val) : 26.8650 / 26.8900
[103-Epoch(22.01sec.)] PSNR(Train / Val) : 26.8720 / 26.8990
[104-Epoch(22.13sec.)] PSNR(Train / Val) : 26.8790 / 26.8730
[105-Epoch(22.11sec.)] PSNR(Train / Val) : 26.8840 / 26.8540
[106-Epoch( 22.2sec.)] PSNR(Train / Val) : 26.8900 / 26.9140
[107-Epoch( 21.8sec.)] PSNR(Train / Val) : 26.8970 / 26.9150
[108-Epoch(22.19sec.)] PSNR(Train / Val) : 26.9020 / 26.9240
[109-Epoch(21.73sec.)] PSNR(Train / Val) : 26.9080 / 26.9170
[110-Epoch(21.61sec.)] PSNR(Train / Val) : 26.9140 / 26.9380
[111-Epoch( 21.5sec.)] PSNR(Train / Val) : 26.9190 / 26.9310
[112-Epoch(21.85sec.)] PSNR(Train / Val) : 26.9250 / 26.9460
[113-Epoch(22.22sec.)] PSNR(Train / Val) : 26.9300 / 26.9560
[114-Epoch(21.74sec.)] PSNR(Train / Val) : 26.9370 / 26.9570
[115-Epoch( 21.8sec.)] PSNR(Train / Val) : 26.9410 / 26.9650
[116-Epoch(21.74sec.)] PSNR(Train / Val) : 26.9490 / 26.9480
[117-Epoch(21.89sec.)] PSNR(Train / Val) : 26.9520 / 26.9220
[118-Epoch(22.23sec.)] PSNR(Train / Val) : 26.9590 / 26.9800
[119-Epoch(22.29sec.)] PSNR(Train / Val) : 26.9630 / 26.9770
[120-Epoch(21.81sec.)] PSNR(Train / Val) : 26.9690 / 27.0020
[121-Epoch(21.96sec.)] PSNR(Train / Val) : 26.9750 / 26.9990
[122-Epoch(21.89sec.)] PSNR(Train / Val) : 26.9800 / 26.9760
[123-Epoch(21.92sec.)] PSNR(Train / Val) : 26.9840 / 26.9990
[124-Epoch(22.45sec.)] PSNR(Train / Val) : 26.9900 / 26.9900
[125-Epoch(21.73sec.)] PSNR(Train / Val) : 26.9940 / 27.0210
[126-Epoch(21.84sec.)] PSNR(Train / Val) : 26.9980 / 27.0160
[127-Epoch(21.74sec.)] PSNR(Train / Val) : 27.0010 / 27.0190
[128-Epoch(21.87sec.)] PSNR(Train / Val) : 27.0080 / 27.0300
[129-Epoch(22.02sec.)] PSNR(Train / Val) : 27.0100 / 27.0380
[130-Epoch( 22.0sec.)] PSNR(Train / Val) : 27.0150 / 27.0170
[131-Epoch(22.01sec.)] PSNR(Train / Val) : 27.0190 / 27.0380
[132-Epoch(22.08sec.)] PSNR(Train / Val) : 27.0230 / 27.0100
[133-Epoch(21.89sec.)] PSNR(Train / Val) : 27.0290 / 27.0310
[134-Epoch(22.27sec.)] PSNR(Train / Val) : 27.0300 / 27.0570
[135-Epoch(22.72sec.)] PSNR(Train / Val) : 27.0350 / 27.0480
[136-Epoch(21.87sec.)] PSNR(Train / Val) : 27.0390 / 27.0580
[137-Epoch(21.79sec.)] PSNR(Train / Val) : 27.0420 / 27.0660
[138-Epoch(21.99sec.)] PSNR(Train / Val) : 27.0470 / 27.0650