UOMOP

2 본문

LSTM

2

Happy PinGu 2024. 7. 23. 20:36
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import math
from params import *
import torchvision
import time


# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")



def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    trainable = 0
    for tensor in model.parameters():
        # Parameters with requires_grad=False are excluded from the count.
        if tensor.requires_grad:
            trainable += tensor.numel()
    return trainable

def psnr(img1, img2, pixel_max=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two tensors.

    The mean squared error is taken over every element of the inputs with
    ``nn.functional.mse_loss``.

    Args:
        img1: First tensor.
        img2: Second tensor, compared element-wise against ``img1``.
        pixel_max: Largest value a pixel can take.  Defaults to 1.0, which
            matches images produced by ``transforms.ToTensor()``; pass 255.0
            for 8-bit intensity ranges.

    Returns:
        ``20 * log10(pixel_max / sqrt(mse))`` as a Python float, or
        ``float('inf')`` when the two tensors are identical (``mse == 0``).
    """
    mse = nn.functional.mse_loss(img1, img2).item()
    if mse == 0:
        # Identical inputs: the ratio is unbounded, and log10 would divide by zero.
        return float('inf')
    return 20 * math.log10(pixel_max / math.sqrt(mse))

# CIFAR-10 train/test splits; ToTensor is the only transformation applied.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Both splits are stored under ./data and downloaded on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only the training batches are shuffled.  The evaluation loader keeps a fixed
# order: the reported metric is an average over the whole split, so shuffling
# adds nothing and only makes evaluation order non-deterministic.
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=False)


class PatchLSTMEncoder(nn.Module):
    """Encode an image batch into a sequence of feature vectors with four LSTMs.

    The input image is cut into non-overlapping square tiles, flattened into a
    ``(batch, sequence_length, input_dim)`` tensor and passed through LSTMs of
    widths 32 -> 64 -> 32 -> ``hidden_dim``.  No activation is applied between
    the LSTMs.

    Args:
        input_dim: Feature size of each sequence step.  Must equal
            ``channels * patch_size ** 2`` for the images given to ``forward``;
            the tile side length is derived from it.
        hidden_dim: Feature size of each step of the encoded sequence.
        num_layer: Number of recurrent layers inside each ``nn.LSTM``.
        sequence_length: Number of steps the flattened image is split into.
    """

    def __init__(self, input_dim, hidden_dim, num_layer, sequence_length):
        super(PatchLSTMEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.lstm1 = nn.LSTM(input_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, hidden_dim, num_layer, batch_first=True)
        # Not used by forward(); kept so the attribute remains available.
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """Return the encoded sequence for a ``(batch, channels, height, width)`` batch.

        Args:
            x: Image batch with four dimensions.

        Returns:
            Tensor of shape ``(batch, sequence_length, hidden_dim)``.

        Raises:
            ValueError: If ``input_dim`` is not ``channels * p * p`` for an
                integer tile side ``p``.
        """
        batch_size, channels, _, _ = x.size()

        # Derive the tile side from the configured feature size instead of
        # reading module-level globals, so the encoder depends only on its
        # own constructor arguments.
        patch_size = math.isqrt(self.input_dim // channels)
        if channels * patch_size * patch_size != self.input_dim:
            raise ValueError(
                "input_dim={} is not channels * patch_size**2 for {} channel(s)".format(
                    self.input_dim, channels))

        # After both unfolds the shape is
        # (batch, channels, tiles_h, tiles_w, patch_size, patch_size).
        tiles = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)

        # NOTE(review): the tensor is flattened without a permute, so each
        # sequence step is a contiguous chunk of this channel-major layout
        # rather than one spatial tile across all channels.  Any module that
        # folds the sequence back into an image must invert this exact layout.
        x = tiles.contiguous().view(batch_size, self.sequence_length, -1)

        lstm_out, _ = self.lstm1(x)
        lstm_out, _ = self.lstm2(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)
        return lstm_out

class PatchLSTMDecoder(nn.Module):
    """Decode a latent sequence back into a 3-channel image with four LSTMs.

    The sequence passes through LSTMs of widths 32 -> 64 -> 32 ->
    ``output_dim`` with no activation in between, and the result is folded
    into an image.

    The fold expects the layout produced by
    ``img.unfold(2, p, p).unfold(3, p, p).contiguous().view(batch, seq, -1)``,
    i.e. memory ordered as (channel, tile_row, tile_col, row_in_tile,
    col_in_tile).  For ``p == 1`` this is the same as viewing the sequence
    directly as ``(batch, 3, side, side)``.

    Args:
        hidden_dim: Feature size of each step of the latent sequence.
        output_dim: Feature size of each decoded step; must be ``3 * p * p``
            for an integer tile side ``p``.
        num_layer: Number of recurrent layers inside each ``nn.LSTM``.
        sequence_length: Number of sequence steps; must be a perfect square
            (the tile grid is square).
    """

    def __init__(self, hidden_dim, output_dim, num_layer, sequence_length):
        super(PatchLSTMDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layer = num_layer
        self.sequence_length = sequence_length
        self.lstm1 = nn.LSTM(hidden_dim, 32, num_layer, batch_first=True)
        self.lstm2 = nn.LSTM(32, 64, num_layer, batch_first=True)
        self.lstm3 = nn.LSTM(64, 32, num_layer, batch_first=True)
        self.lstm4 = nn.LSTM(32, output_dim, num_layer, batch_first=True)
        # Not used by forward(); kept so the attributes remain available.
        self.relu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

    def _sequence_to_image(self, sequence):
        """Fold a ``(batch, sequence_length, output_dim)`` tensor into images.

        Returns:
            Tensor of shape ``(batch, 3, grid * p, grid * p)`` where
            ``grid ** 2 == sequence_length`` and ``3 * p * p == output_dim``.

        Raises:
            ValueError: If the configured sizes do not describe a square grid
                of square 3-channel tiles.
        """
        batch_size = sequence.size(0)
        grid = math.isqrt(self.sequence_length)
        patch = math.isqrt(self.output_dim // 3)
        if grid * grid != self.sequence_length or 3 * patch * patch != self.output_dim:
            raise ValueError(
                "cannot fold sequence_length={} / output_dim={} into a square "
                "3-channel image".format(self.sequence_length, self.output_dim))

        tiles = sequence.contiguous().view(batch_size, 3, grid, grid, patch, patch)
        # Interleave tile rows with in-tile rows (and likewise for columns) so
        # pixel (i * p + u, j * p + v) comes from tile (i, j), offset (u, v).
        return tiles.permute(0, 1, 2, 4, 3, 5).reshape(
            batch_size, 3, grid * patch, grid * patch)

    def forward(self, x):
        """Return the decoded image batch for a ``(batch, sequence_length, hidden_dim)`` input."""
        lstm_out, _ = self.lstm1(x)
        lstm_out, _ = self.lstm2(lstm_out)
        lstm_out, _ = self.lstm3(lstm_out)
        lstm_out, _ = self.lstm4(lstm_out)
        return self._sequence_to_image(lstm_out)

# Model, Loss, Optimizer

class Model(nn.Module):
    """Autoencoder that feeds the output of a PatchLSTMEncoder into a PatchLSTMDecoder.

    The latent width and the LSTM depth are read from ``params['hidden_dim']``
    and ``params['num_layer']``; ``input_dim`` and ``sequence_length`` are
    forwarded to both halves.
    """

    def __init__(self, input_dim, sequence_length):
        super().__init__()
        self.input_dim, self.sequence_length = input_dim, sequence_length

        latent_dim = params['hidden_dim']
        depth = params['num_layer']
        # The decoder mirrors the encoder: it maps latent_dim back to input_dim.
        self.encoder = PatchLSTMEncoder(input_dim, latent_dim, depth, sequence_length)
        self.decoder = PatchLSTMDecoder(latent_dim, input_dim, depth, sequence_length)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for that encoding."""
        return self.decoder(self.encoder(x))


# 32 is the image side length used here, so each side holds 32 // patch_size
# tiles and every tile contributes 3 * patch_size**2 values per sequence step.
patches_per_side = 32 // params['patch_size']
sequence_length = patches_per_side ** 2
input_dim = params['patch_size'] ** 2 * 3

model = Model(input_dim, sequence_length)
model = model.to(device)

# Mean squared error between the decoded output and the input image, minimised with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=params['LR'])

num_epochs = params['EP']

print("Model size : {}".format(count_parameters(model)))

def _psnr_from_mse(mse, pixel_max=1.0):
    """Convert a mean squared error into PSNR in dB (``inf`` when the error is zero)."""
    if mse <= 0:
        # A zero loss would otherwise raise ZeroDivisionError below.
        return float('inf')
    return 10 * math.log10(pixel_max ** 2 / mse)


def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return the per-sample average loss.

    When ``train`` is true the model is put in training mode and the
    optimizer is stepped after every batch; otherwise the model is put in
    evaluation mode and gradient tracking is disabled.
    """
    model.train(train)
    loss_sum = 0.0
    sample_count = 0

    with torch.set_grad_enabled(train):
        for images, _ in loader:
            images = images.to(device)
            if train:
                optimizer.zero_grad()

            decoded = model(images)
            loss = criterion(decoded, images)

            if train:
                loss.backward()
                optimizer.step()

            # Weight each batch by its size so that a shorter final batch
            # does not count as much as a full one in the epoch average.
            batch_samples = images.size(0)
            loss_sum += loss.item() * batch_samples
            sample_count += batch_samples

    return loss_sum / sample_count


for epoch in range(num_epochs):
    timetemp = time.time()

    train_avg_loss = _run_epoch(trainloader, train=True)
    train_avg_psnr = round(_psnr_from_mse(train_avg_loss), 3)

    #=========================================

    test_avg_loss = _run_epoch(testloader, train=False)
    test_avg_psnr = round(_psnr_from_mse(test_avg_loss), 3)

    # The elapsed time covers both the training pass and the evaluation pass.
    training_time = time.time() - timetemp

    print(
        "[{:>3}-Epoch({:>5}sec.)]  PSNR(Train / Val) : {:>6.4f} / {:>6.4f}".format(
            epoch + 1, round(training_time, 2), train_avg_psnr, test_avg_psnr))


print("Training completed.")
# Hyper-parameters read by the training script through ``params[...]``.
# NOTE(review): presumably the contents of the ``params`` module that the
# script imports with ``from params import *`` — confirm.
params = {
    # Optimisation settings.
    'BS': 64,       # batch size passed to both DataLoaders
    'LR': 0.008,    # learning rate passed to Adam
    'EP': 3000,     # number of epochs in the training loop

    # Model geometry.
    'patch_size': 1,    # side length of each square tile cut from the image
    'hidden_dim': 1,    # feature size of the sequence between encoder and decoder
    'num_layer': 1,     # recurrent layers inside each nn.LSTM
}

[ 12-Epoch(21.32sec.)]  PSNR(Train / Val) : 26.8620 / 26.9130
[ 13-Epoch(22.06sec.)]  PSNR(Train / Val) : 26.8960 / 26.9490
[ 14-Epoch(21.43sec.)]  PSNR(Train / Val) : 26.9480 / 27.0240
[ 15-Epoch(21.44sec.)]  PSNR(Train / Val) : 26.9880 / 27.0520
[ 16-Epoch(21.43sec.)]  PSNR(Train / Val) : 27.0260 / 27.0560
[ 17-Epoch(21.64sec.)]  PSNR(Train / Val) : 27.0620 / 27.0140
[ 18-Epoch(21.57sec.)]  PSNR(Train / Val) : 27.0990 / 27.1650
[ 19-Epoch(21.62sec.)]  PSNR(Train / Val) : 27.1120 / 27.1260
[ 20-Epoch(21.27sec.)]  PSNR(Train / Val) : 27.1720 / 27.2180
[ 21-Epoch(21.12sec.)]  PSNR(Train / Val) : 27.1860 / 27.2270
[ 22-Epoch(21.04sec.)]  PSNR(Train / Val) : 27.1550 / 27.2280
[ 23-Epoch(21.16sec.)]  PSNR(Train / Val) : 27.2400 / 27.2420
[ 24-Epoch(21.85sec.)]  PSNR(Train / Val) : 27.2370 / 27.1760
[ 25-Epoch(21.39sec.)]  PSNR(Train / Val) : 27.2660 / 27.3200
[ 26-Epoch(21.12sec.)]  PSNR(Train / Val) : 27.2370 / 27.2950
[ 27-Epoch(21.25sec.)]  PSNR(Train / Val) : 27.2820 / 27.3390
[ 28-Epoch(21.55sec.)]  PSNR(Train / Val) : 27.2790 / 27.3410

'LSTM' 카테고리의 다른 글

Image transmission using LSTM  (0) 2024.07.23
Comments