UOMOP

Image transmission using LSTM 본문

LSTM

Image transmission using LSTM

Happy PinGu 2024. 7. 23. 19:23
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import math
from params import *
import torchvision
import time


# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0") if _cuda_ok else torch.device("cpu")



def count_parameters(model):
    """Return how many scalar values in *model* are trainable.

    Only parameters with ``requires_grad`` set are counted; frozen tensors
    are skipped. Returns 0 for a model without trainable parameters.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total

def psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    PSNR = 10 * log10(max_val**2 / MSE), where MSE is the mean squared error
    over all elements (``nn.functional.mse_loss`` with its default ``'mean'``
    reduction).

    Args:
        img1, img2: tensors of the same shape, floating-point dtype.
        max_val: peak signal value. 1.0 (the default) suits images scaled to
            [0, 1]; use 255.0 for 8-bit pixel values.

    Returns:
        PSNR as a Python float, or ``float('inf')`` when the inputs are identical.
    """
    # Convert to a plain float once: avoids a tensor truth test followed by a
    # second device->host transfer.
    mse = nn.functional.mse_loss(img1, img2).item()
    if mse == 0:
        return float('inf')
    return 10 * math.log10(max_val ** 2 / mse)

# CIFAR-10 with ToTensor() only: images become float tensors in [0, 1], which is
# the peak value of 1.0 assumed by the PSNR computations.
transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps the validation
# batches (including the short final batch) identical every epoch, so the
# reported validation metric is reproducible.
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=False)


class PatchLSTMEncoder(nn.Module):
    """Encode an image batch into a sequence of per-patch latent vectors.

    The image is cut into non-overlapping p x p patches, ordered row-major over
    the patch grid. Each patch, flattened together with all of its channels,
    is one time step fed through four stacked LSTMs
    (input_dim -> 32 -> 64 -> 32 -> hidden_dim) with ReLU between them.

    Args:
        input_dim: flattened patch length, ``C * p * p``.
        hidden_dim: size of the latent vector emitted per patch.
        num_layer: ``num_layers`` of every ``nn.LSTM`` in the stack.
        sequence_length: number of patches per image. It is stored for
            reference; the actual length is taken from the input.
        patch_size: patch side ``p``. When ``None`` it is recovered from
            ``input_dim`` and the channel count of the input.
    """

    def __init__(self, input_dim, hidden_dim, num_layer, sequence_length, patch_size=None):
        super(PatchLSTMEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.patch_size = patch_size
        self.lstm1 = nn.LSTM(input_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, hidden_dim, num_layer, batch_first=True)
        self.relu = nn.ReLU()

    @staticmethod
    def patchify(x, patch_size):
        """Turn (B, C, H, W) images into (B, (H//p)*(W//p), C*p*p) patch sequences.

        H and W should be multiples of ``patch_size``; ``Tensor.unfold`` drops
        any remainder. Each sequence step holds one spatial patch with all of
        its channels, flattened in (C, p, p) order.
        """
        batch_size, channels = x.size(0), x.size(1)
        # (B, C, H, W) -> (B, C, H//p, W//p, p, p)
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        # Move the patch-grid axes in front of the channel axis; flattening
        # with C first would otherwise mix values from different patches
        # of one channel into a single step.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        return patches.reshape(batch_size, -1, channels * patch_size * patch_size)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_patches, hidden_dim) latents."""
        channels = x.size(1)
        p = self.patch_size
        if p is None:
            # input_dim == C * p * p, so p can be recovered from it.
            p = math.isqrt(self.input_dim // channels)
        if channels * p * p != self.input_dim:
            raise ValueError(
                f"input_dim={self.input_dim} does not match {channels} channels "
                f"with patch_size={p}")
        out = self.patchify(x, p)
        for lstm in (self.lstm1, self.lstm2, self.lstm3):
            out, _ = lstm(out)
            out = self.relu(out)
        out, _ = self.lstm4(out)
        return out

class PatchLSTMDecoder(nn.Module):
    """Decode a sequence of per-patch latents back into an image batch.

    Four stacked LSTMs (hidden_dim -> 32 -> 64 -> 32 -> output_dim) run with
    ReLU between them. Their per-step outputs are then placed back onto the
    square patch grid, exactly inverting ``PatchLSTMEncoder.patchify``.

    Args:
        hidden_dim: latent size per patch (input size of the first LSTM).
        output_dim: flattened patch length produced per step, ``C * p * p``.
        num_layer: ``num_layers`` of every ``nn.LSTM`` in the stack.
        sequence_length: number of patches per image; must be a perfect
            square (square image, square patch grid).
        channels: image channel count C (3 for RGB).
        patch_size: patch side ``p``; recovered from ``output_dim`` when None.
    """

    def __init__(self, hidden_dim, output_dim, num_layer, sequence_length, channels=3, patch_size=None):
        super(PatchLSTMDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.channels = channels
        self.patch_size = patch_size
        self.lstm1 = nn.LSTM(hidden_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, output_dim, num_layer, batch_first=True)
        self.relu = nn.ReLU()
        # Defined but not applied in forward(); LSTM hidden outputs already
        # lie in (-1, 1).
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def unpatchify(seq, channels, patch_size, grid_h, grid_w):
        """Invert ``patchify``: (B, grid_h*grid_w, C*p*p) -> (B, C, grid_h*p, grid_w*p)."""
        batch_size = seq.size(0)
        p = patch_size
        x = seq.reshape(batch_size, grid_h, grid_w, channels, p, p)
        # (B, gh, gw, C, p, p) -> (B, C, gh, p, gw, p), then merge (gh, p) and (gw, p).
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(batch_size, channels, grid_h * p, grid_w * p)

    def forward(self, x):
        """Map (B, sequence_length, hidden_dim) latents to (B, C, H, W) images."""
        out = x
        for lstm in (self.lstm1, self.lstm2, self.lstm3):
            out, _ = lstm(out)
            out = self.relu(out)
        out, _ = self.lstm4(out)
        p = self.patch_size
        if p is None:
            # output_dim == C * p * p, so p can be recovered from it.
            p = math.isqrt(self.output_dim // self.channels)
        grid = math.isqrt(self.sequence_length)
        if grid * grid != self.sequence_length or self.channels * p * p != self.output_dim:
            raise ValueError(
                f"cannot form a square {self.channels}-channel image from "
                f"sequence_length={self.sequence_length}, output_dim={self.output_dim}")
        # The image side is grid * p, not sqrt(sequence_length): the old
        # formula only worked for patch_size == 1.
        return self.unpatchify(out, self.channels, p, grid, grid)

# Model, Loss, Optimizer

class Model(nn.Module):
    """Patch-LSTM autoencoder: a PatchLSTMEncoder followed by a PatchLSTMDecoder.

    Args:
        input_dim: flattened patch length consumed by the encoder and
            reproduced per step by the decoder.
        sequence_length: number of patches per image.

    The latent width and LSTM depth are read from the module-level ``params``
    dict (keys ``'hidden_dim'`` and ``'num_layer'``).
    """

    def __init__(self, input_dim, sequence_length):
        super().__init__()
        self.input_dim = input_dim
        self.sequence_length = sequence_length

        latent = params['hidden_dim']
        depth = params['num_layer']
        self.encoder = PatchLSTMEncoder(input_dim, latent, depth, sequence_length)
        self.decoder = PatchLSTMDecoder(latent, input_dim, depth, sequence_length)

    def forward(self, x):
        """Return the decoder's reconstruction of the image batch *x*."""
        return self.decoder(self.encoder(x))


# Model dimensions for 32x32 CIFAR-10 images cut into square patches.
_patch = params['patch_size']
_patches_per_side = 32 // _patch
sequence_length = _patches_per_side * _patches_per_side  # patches per image
input_dim = 3 * _patch ** 2  # RGB values per patch

model = Model(input_dim, sequence_length)
model = model.to(device)

criterion = nn.MSELoss()
num_epochs = params['EP']
optimizer = torch.optim.Adam(model.parameters(), lr=params['LR'])

_n_trainable = count_parameters(model)
print("Model size : {}".format(_n_trainable))

def _loss_to_psnr(mse):
    """Convert a mean squared error on [0, 1]-scaled images into PSNR (dB)."""
    return float('inf') if mse == 0 else 10 * math.log10(1.0 / mse)


for epoch in range(num_epochs):
    epoch_start = time.time()

    # ---- training ----
    model.train()
    train_loss_sum = 0.0
    for images, _ in trainloader:
        images = images.to(device)
        optimizer.zero_grad()

        decoded = model(images)
        loss = criterion(decoded, images)
        loss.backward()
        optimizer.step()

        # criterion (nn.MSELoss, mean reduction) averages within the batch;
        # weight by batch size so the smaller final batch is not over-counted.
        train_loss_sum += loss.item() * images.size(0)

    train_avg_loss = train_loss_sum / len(trainloader.dataset)
    train_avg_psnr = _loss_to_psnr(train_avg_loss)

    # ---- validation ----
    model.eval()
    test_loss_sum = 0.0
    with torch.no_grad():
        for images, _ in testloader:
            images = images.to(device)
            decoded = model(images)
            test_loss_sum += criterion(decoded, images).item() * images.size(0)

    test_avg_loss = test_loss_sum / len(testloader.dataset)
    test_avg_psnr = _loss_to_psnr(test_avg_loss)

    # Wall-clock time of the whole epoch (training + validation).
    epoch_time = time.time() - epoch_start

    print(
        "[{:>3}-Epoch({:>5.2f}sec.)]  PSNR(Train / Val) : {:>6.3f} / {:>6.3f}".format(
            epoch + 1, epoch_time, train_avg_psnr, test_avg_psnr))


print("Training completed.")
# Hyper-parameters; presumably the contents of params.py, which the script
# pulls in with `from params import *`.
params = dict(
    BS=64,           # DataLoader batch size
    LR=0.0005,       # Adam learning rate
    EP=3000,         # number of training epochs
    patch_size=1,    # side of each square image patch (1 = one pixel per step)
    hidden_dim=1,    # latent values emitted per patch by the encoder
    num_layer=1,     # num_layers of every nn.LSTM in encoder and decoder
)

C:\Users\dowon\anaconda3\envs\dowon_simul\python.exe C:\Users\dowon\main_research\LSTM_cifar10\main.py 
Files already downloaded and verified
Files already downloaded and verified
Model size : 85064
[  1-Epoch(22.02sec.)]  PSNR(Train / Val) : 13.9430 / 16.9690
[  2-Epoch(21.84sec.)]  PSNR(Train / Val) : 17.1040 / 17.2060
[  3-Epoch( 21.9sec.)]  PSNR(Train / Val) : 17.3630 / 17.8890
[  4-Epoch( 21.6sec.)]  PSNR(Train / Val) : 18.3060 / 18.4280
[  5-Epoch( 21.8sec.)]  PSNR(Train / Val) : 18.5470 / 18.6290
[  6-Epoch(21.76sec.)]  PSNR(Train / Val) : 18.6550 / 18.7290
[  7-Epoch(21.62sec.)]  PSNR(Train / Val) : 18.8380 / 19.1920
[  8-Epoch( 21.6sec.)]  PSNR(Train / Val) : 19.4620 / 19.6570
[  9-Epoch(22.48sec.)]  PSNR(Train / Val) : 20.8220 / 23.5100
[ 10-Epoch(22.17sec.)]  PSNR(Train / Val) : 23.7390 / 23.9160
[ 11-Epoch(21.57sec.)]  PSNR(Train / Val) : 24.0340 / 24.0520
[ 12-Epoch(21.66sec.)]  PSNR(Train / Val) : 24.2630 / 24.3960
[ 13-Epoch(21.66sec.)]  PSNR(Train / Val) : 24.4090 / 24.2440
[ 14-Epoch(21.66sec.)]  PSNR(Train / Val) : 24.4860 / 24.5430
[ 15-Epoch(21.98sec.)]  PSNR(Train / Val) : 24.5340 / 24.1220
[ 16-Epoch(21.91sec.)]  PSNR(Train / Val) : 24.5780 / 24.5370
[ 17-Epoch(21.82sec.)]  PSNR(Train / Val) : 24.6180 / 24.6780
[ 18-Epoch(21.67sec.)]  PSNR(Train / Val) : 24.6580 / 24.6570
[ 19-Epoch(25.06sec.)]  PSNR(Train / Val) : 24.6850 / 24.7450
[ 20-Epoch( 22.4sec.)]  PSNR(Train / Val) : 24.7190 / 24.7740
[ 21-Epoch( 21.7sec.)]  PSNR(Train / Val) : 24.7520 / 24.7650
[ 22-Epoch(22.09sec.)]  PSNR(Train / Val) : 24.7740 / 24.6580
[ 23-Epoch(21.92sec.)]  PSNR(Train / Val) : 24.8120 / 24.8110
[ 24-Epoch(22.16sec.)]  PSNR(Train / Val) : 24.8450 / 24.9040
[ 25-Epoch(22.23sec.)]  PSNR(Train / Val) : 24.8850 / 24.9240
[ 26-Epoch( 22.6sec.)]  PSNR(Train / Val) : 24.9300 / 24.9720
[ 27-Epoch(21.83sec.)]  PSNR(Train / Val) : 24.9610 / 25.0130
[ 28-Epoch(22.08sec.)]  PSNR(Train / Val) : 25.0060 / 25.0310
[ 29-Epoch(21.88sec.)]  PSNR(Train / Val) : 25.0700 / 25.1470
[ 30-Epoch(21.83sec.)]  PSNR(Train / Val) : 25.1480 / 25.2230
[ 31-Epoch(22.27sec.)]  PSNR(Train / Val) : 25.2460 / 25.2810
[ 32-Epoch(21.98sec.)]  PSNR(Train / Val) : 25.3520 / 25.4200
[ 33-Epoch(22.23sec.)]  PSNR(Train / Val) : 25.4650 / 25.3800
[ 34-Epoch(22.11sec.)]  PSNR(Train / Val) : 25.5500 / 25.6030
[ 35-Epoch(22.05sec.)]  PSNR(Train / Val) : 25.6130 / 25.6650
[ 36-Epoch(22.31sec.)]  PSNR(Train / Val) : 25.6710 / 25.7140
[ 37-Epoch(22.37sec.)]  PSNR(Train / Val) : 25.7260 / 25.7790
[ 38-Epoch(22.02sec.)]  PSNR(Train / Val) : 25.7790 / 25.6810
[ 39-Epoch(21.68sec.)]  PSNR(Train / Val) : 25.8310 / 25.8970
[ 40-Epoch(21.52sec.)]  PSNR(Train / Val) : 25.8820 / 25.9380
[ 41-Epoch(21.83sec.)]  PSNR(Train / Val) : 25.9390 / 26.0050
[ 42-Epoch(22.55sec.)]  PSNR(Train / Val) : 26.0220 / 26.0980
[ 43-Epoch(21.82sec.)]  PSNR(Train / Val) : 26.1020 / 26.1300
[ 44-Epoch(21.59sec.)]  PSNR(Train / Val) : 26.1550 / 26.2100
[ 45-Epoch(21.69sec.)]  PSNR(Train / Val) : 26.1960 / 26.2500
[ 46-Epoch(21.84sec.)]  PSNR(Train / Val) : 26.2290 / 26.2780
[ 47-Epoch(21.71sec.)]  PSNR(Train / Val) : 26.2590 / 26.3020
[ 48-Epoch(22.25sec.)]  PSNR(Train / Val) : 26.2830 / 26.3400
[ 49-Epoch(21.81sec.)]  PSNR(Train / Val) : 26.3100 / 26.3530
[ 50-Epoch(21.79sec.)]  PSNR(Train / Val) : 26.3300 / 26.3270
[ 51-Epoch(21.77sec.)]  PSNR(Train / Val) : 26.3550 / 26.3870
[ 52-Epoch( 22.0sec.)]  PSNR(Train / Val) : 26.3950 / 26.3850
[ 53-Epoch(22.57sec.)]  PSNR(Train / Val) : 26.4720 / 26.5260
[ 54-Epoch(21.96sec.)]  PSNR(Train / Val) : 26.5150 / 26.5530
[ 55-Epoch(22.04sec.)]  PSNR(Train / Val) : 26.5400 / 26.5750
[ 56-Epoch(21.97sec.)]  PSNR(Train / Val) : 26.5570 / 26.5810
[ 57-Epoch(21.73sec.)]  PSNR(Train / Val) : 26.5720 / 26.6120
[ 58-Epoch(21.88sec.)]  PSNR(Train / Val) : 26.5850 / 26.6200
[ 59-Epoch(21.89sec.)]  PSNR(Train / Val) : 26.5940 / 26.4810
[ 60-Epoch(21.55sec.)]  PSNR(Train / Val) : 26.6040 / 26.6110
[ 61-Epoch( 21.7sec.)]  PSNR(Train / Val) : 26.6140 / 26.6590
[ 62-Epoch(21.65sec.)]  PSNR(Train / Val) : 26.6240 / 26.6400
[ 63-Epoch(21.75sec.)]  PSNR(Train / Val) : 26.6330 / 26.6650
[ 64-Epoch(21.98sec.)]  PSNR(Train / Val) : 26.6430 / 26.6710
[ 65-Epoch(21.76sec.)]  PSNR(Train / Val) : 26.6460 / 26.6830
[ 66-Epoch(21.72sec.)]  PSNR(Train / Val) : 26.6570 / 26.6450
[ 67-Epoch(21.79sec.)]  PSNR(Train / Val) : 26.6620 / 26.6790
[ 68-Epoch(21.79sec.)]  PSNR(Train / Val) : 26.6690 / 26.6350
[ 69-Epoch(22.54sec.)]  PSNR(Train / Val) : 26.6750 / 26.7070
[ 70-Epoch(22.28sec.)]  PSNR(Train / Val) : 26.6810 / 26.6970
[ 71-Epoch(21.72sec.)]  PSNR(Train / Val) : 26.6900 / 26.7200
[ 72-Epoch(21.96sec.)]  PSNR(Train / Val) : 26.6950 / 26.7210
[ 73-Epoch(21.84sec.)]  PSNR(Train / Val) : 26.7010 / 26.7210
[ 74-Epoch(21.93sec.)]  PSNR(Train / Val) : 26.7060 / 26.7130
[ 75-Epoch(22.21sec.)]  PSNR(Train / Val) : 26.7130 / 26.7320
[ 76-Epoch(21.72sec.)]  PSNR(Train / Val) : 26.7200 / 26.7270
[ 77-Epoch(21.88sec.)]  PSNR(Train / Val) : 26.7260 / 26.7490
[ 78-Epoch(21.69sec.)]  PSNR(Train / Val) : 26.7310 / 26.7060
[ 79-Epoch(21.82sec.)]  PSNR(Train / Val) : 26.7360 / 26.7580
[ 80-Epoch(22.16sec.)]  PSNR(Train / Val) : 26.7440 / 26.7680
[ 81-Epoch(22.18sec.)]  PSNR(Train / Val) : 26.7470 / 26.7390
[ 82-Epoch(21.75sec.)]  PSNR(Train / Val) : 26.7530 / 26.7750
[ 83-Epoch( 21.7sec.)]  PSNR(Train / Val) : 26.7590 / 26.7790
[ 84-Epoch(21.82sec.)]  PSNR(Train / Val) : 26.7640 / 26.7840
[ 85-Epoch( 22.3sec.)]  PSNR(Train / Val) : 26.7690 / 26.7510
[ 86-Epoch(22.54sec.)]  PSNR(Train / Val) : 26.7740 / 26.7860
[ 87-Epoch(22.37sec.)]  PSNR(Train / Val) : 26.7800 / 26.8020
[ 88-Epoch(21.72sec.)]  PSNR(Train / Val) : 26.7870 / 26.8110
[ 89-Epoch( 22.3sec.)]  PSNR(Train / Val) : 26.7880 / 26.8130
[ 90-Epoch(22.13sec.)]  PSNR(Train / Val) : 26.7950 / 26.8070
[ 91-Epoch(24.31sec.)]  PSNR(Train / Val) : 26.8000 / 26.8280
[ 92-Epoch(22.32sec.)]  PSNR(Train / Val) : 26.8070 / 26.8170
[ 93-Epoch(22.07sec.)]  PSNR(Train / Val) : 26.8120 / 26.8230
[ 94-Epoch(21.75sec.)]  PSNR(Train / Val) : 26.8170 / 26.8180
[ 95-Epoch( 21.7sec.)]  PSNR(Train / Val) : 26.8230 / 26.8380
[ 96-Epoch(22.01sec.)]  PSNR(Train / Val) : 26.8310 / 26.8450
[ 97-Epoch(22.39sec.)]  PSNR(Train / Val) : 26.8350 / 26.8510
[ 98-Epoch(21.99sec.)]  PSNR(Train / Val) : 26.8390 / 26.8680
[ 99-Epoch(21.94sec.)]  PSNR(Train / Val) : 26.8470 / 26.8690
[100-Epoch(21.81sec.)]  PSNR(Train / Val) : 26.8530 / 26.8580
[101-Epoch(22.35sec.)]  PSNR(Train / Val) : 26.8580 / 26.8880
[102-Epoch(22.58sec.)]  PSNR(Train / Val) : 26.8650 / 26.8900
[103-Epoch(22.01sec.)]  PSNR(Train / Val) : 26.8720 / 26.8990
[104-Epoch(22.13sec.)]  PSNR(Train / Val) : 26.8790 / 26.8730
[105-Epoch(22.11sec.)]  PSNR(Train / Val) : 26.8840 / 26.8540
[106-Epoch( 22.2sec.)]  PSNR(Train / Val) : 26.8900 / 26.9140
[107-Epoch( 21.8sec.)]  PSNR(Train / Val) : 26.8970 / 26.9150
[108-Epoch(22.19sec.)]  PSNR(Train / Val) : 26.9020 / 26.9240
[109-Epoch(21.73sec.)]  PSNR(Train / Val) : 26.9080 / 26.9170
[110-Epoch(21.61sec.)]  PSNR(Train / Val) : 26.9140 / 26.9380
[111-Epoch( 21.5sec.)]  PSNR(Train / Val) : 26.9190 / 26.9310
[112-Epoch(21.85sec.)]  PSNR(Train / Val) : 26.9250 / 26.9460
[113-Epoch(22.22sec.)]  PSNR(Train / Val) : 26.9300 / 26.9560
[114-Epoch(21.74sec.)]  PSNR(Train / Val) : 26.9370 / 26.9570
[115-Epoch( 21.8sec.)]  PSNR(Train / Val) : 26.9410 / 26.9650
[116-Epoch(21.74sec.)]  PSNR(Train / Val) : 26.9490 / 26.9480
[117-Epoch(21.89sec.)]  PSNR(Train / Val) : 26.9520 / 26.9220
[118-Epoch(22.23sec.)]  PSNR(Train / Val) : 26.9590 / 26.9800
[119-Epoch(22.29sec.)]  PSNR(Train / Val) : 26.9630 / 26.9770
[120-Epoch(21.81sec.)]  PSNR(Train / Val) : 26.9690 / 27.0020
[121-Epoch(21.96sec.)]  PSNR(Train / Val) : 26.9750 / 26.9990
[122-Epoch(21.89sec.)]  PSNR(Train / Val) : 26.9800 / 26.9760
[123-Epoch(21.92sec.)]  PSNR(Train / Val) : 26.9840 / 26.9990
[124-Epoch(22.45sec.)]  PSNR(Train / Val) : 26.9900 / 26.9900
[125-Epoch(21.73sec.)]  PSNR(Train / Val) : 26.9940 / 27.0210
[126-Epoch(21.84sec.)]  PSNR(Train / Val) : 26.9980 / 27.0160
[127-Epoch(21.74sec.)]  PSNR(Train / Val) : 27.0010 / 27.0190
[128-Epoch(21.87sec.)]  PSNR(Train / Val) : 27.0080 / 27.0300
[129-Epoch(22.02sec.)]  PSNR(Train / Val) : 27.0100 / 27.0380
[130-Epoch( 22.0sec.)]  PSNR(Train / Val) : 27.0150 / 27.0170
[131-Epoch(22.01sec.)]  PSNR(Train / Val) : 27.0190 / 27.0380
[132-Epoch(22.08sec.)]  PSNR(Train / Val) : 27.0230 / 27.0100
[133-Epoch(21.89sec.)]  PSNR(Train / Val) : 27.0290 / 27.0310
[134-Epoch(22.27sec.)]  PSNR(Train / Val) : 27.0300 / 27.0570
[135-Epoch(22.72sec.)]  PSNR(Train / Val) : 27.0350 / 27.0480
[136-Epoch(21.87sec.)]  PSNR(Train / Val) : 27.0390 / 27.0580
[137-Epoch(21.79sec.)]  PSNR(Train / Val) : 27.0420 / 27.0660
[138-Epoch(21.99sec.)]  PSNR(Train / Val) : 27.0470 / 27.0650

'LSTM' 카테고리의 다른 글

2  (0) 2024.07.23
Comments