UOMOP

DeepJSCC performance (DIM = 768, 1536, 2304) 본문

Main

DeepJSCC performance (DIM = 768, 1536, 2304)

Happy PinGu 2024. 5. 17. 17:39
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import time
import os
from params import *

# Use the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only tensors with ``requires_grad=True`` are counted; frozen parameters
    are ignored. A module without trainable parameters yields 0.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

# Only ToTensor(): images become float tensors with values in [0, 1], the same
# range as the decoder's Sigmoid output.
transf = tr.Compose([tr.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)

# Shuffle only the training split. The test split is only used to compute an
# average loss; shuffling it just makes the composition of the last (partial)
# batch -- and therefore the averaged per-batch metric -- non-reproducible.
trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=False)

class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a flat latent vector.

    Four 5x5 convolutions (the first three with stride 2), each followed by a
    PReLU, then a flatten and a linear projection to ``latent_dim`` features.
    The ``4 * 4 * 64`` input size of the final linear layer is sized for
    3-channel 32x32 images (32 -> 16 -> 8 -> 4 after the strided stages).

    Args:
        latent_dim: number of output features per image.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # (in_channels, out_channels, stride) per conv stage; padding=2 with a
        # 5x5 kernel keeps the size unchanged before striding.
        conv_stages = [
            (3, 16, 2),   # -> 16x16
            (16, 32, 2),  # -> 8x8
            (32, 64, 2),  # -> 4x4
            (64, 64, 1),  # -> 4x4
        ]

        layers = []
        for in_ch, out_ch, stride in conv_stages:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=stride, padding=2))
            layers.append(nn.PReLU())
        layers.append(nn.Flatten())
        layers.append(nn.Linear(4 * 4 * 64, latent_dim))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode the image batch ``x`` into a ``(batch, latent_dim)`` tensor."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Mirror of the encoder: maps a latent vector back to a 3x32x32 image.

    A linear layer expands the latent to a 64x4x4 feature map. A stride-1
    transposed conv is followed by three stride-2 transposed convs (4 -> 8 ->
    16 -> 32), with shape-preserving 5x5 convs in between. Every hidden layer
    uses PReLU, and a final Sigmoid keeps pixel values in (0, 1).

    Args:
        latent_dim: number of input features per sample.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        def upsample(in_ch, out_ch):
            # Stride-2 transposed conv that doubles height and width.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2, output_padding=1)

        def refine(channels):
            # 5x5 conv that keeps both spatial size and channel count.
            return nn.Conv2d(channels, channels, kernel_size=5, padding=2)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 4 * 4 * 64), nn.PReLU(),
            nn.Unflatten(1, (64, 4, 4)),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2), nn.PReLU(),  # 4x4
            upsample(64, 32), nn.PReLU(),  # 8x8
            refine(32), nn.PReLU(),
            upsample(32, 16), nn.PReLU(),  # 16x16
            refine(16), nn.PReLU(),
            upsample(16, 3),               # 32x32
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a ``(batch, latent_dim)`` tensor into a ``(batch, 3, 32, 32)`` image."""
        return self.decoder(x)

class Autoencoder(nn.Module):
    """End-to-end DeepJSCC model: encoder -> AWGN channel -> decoder.

    Args:
        latent_dim: size of the latent vector passed to both sub-modules; also
            used to scale the channel noise.
        encoder_class: module class instantiated as ``encoder_class(latent_dim=latent_dim)``.
        decoder_class: module class instantiated as ``decoder_class(latent_dim=latent_dim)``.
    """

    def __init__(
            self,
            latent_dim,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.latent_dim = latent_dim
        self.encoder = encoder_class(latent_dim=latent_dim)
        self.decoder = decoder_class(latent_dim=latent_dim)

    def AWGN(self, input, SNRdB):
        """Power-normalise ``input`` and add white Gaussian noise at ``SNRdB``.

        Each row (``dim=1``) is scaled to unit L2 norm, so for a
        ``(batch, latent_dim)`` input the average power per element is
        ``1 / latent_dim``. The noise std ``1 / sqrt(latent_dim * SNR)`` gives
        a noise variance of ``1 / (latent_dim * SNR)``, so the per-element
        signal-to-noise ratio is ``SNR``.

        The noise is drawn with ``torch.randn_like``. It is therefore created
        directly on the input's device with the input's dtype. It is neither
        built on the CPU and copied on every call, nor tied to a module-level
        device that may differ from where the model actually runs. Noise is
        added in both train and eval mode.

        Args:
            input: encoder output tensor; presumably ``(batch, latent_dim)``.
            SNRdB: channel SNR in decibels.

        Returns:
            Tensor with the same shape as ``input``: normalised signal plus noise.
        """
        normalized_tensor = f.normalize(input, dim=1)

        snr_linear = 10.0 ** (SNRdB / 10.0)
        std = 1 / math.sqrt(self.latent_dim * snr_linear)
        noise = torch.randn_like(normalized_tensor) * std

        return normalized_tensor + noise

    def forward(self, x, SNRdB):
        """Encode ``x``, send the latent through the AWGN channel, and decode it."""
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        return self.decoder(channel_output)


def _psnr(mse):
    """Return the PSNR in dB, rounded to 3 decimals, for a peak pixel value of 1.0.

    A non-positive MSE yields ``inf`` instead of raising ZeroDivisionError.
    """
    if mse <= 0:
        return float('inf')
    return round(10 * math.log10(1.0 / mse), 3)


def _train_one_epoch(model, loader, criterion, optimizer, snr_db):
    """Run one optimisation pass over ``loader`` at ``snr_db``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for data in loader:
        # Only the images (data[0]) are used: the reconstruction target is the input itself.
        inputs = data[0].to(device)
        optimizer.zero_grad()
        outputs = model(inputs, SNRdB=snr_db)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(model, loader, criterion, snr_db):
    """Compute the mean per-batch loss over ``loader`` at ``snr_db`` without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in loader:
            inputs = data[0].to(device)
            outputs = model(inputs, SNRdB=snr_db)
            total_loss += criterion(outputs, inputs).item()
    return total_loss / len(loader)


def model_train(trainloader, testloader, latent_dims, n_epochs_stop=20, save_folder='trained_model'):
    """Train one autoencoder per (latent_dim, SNR) pair and keep its best checkpoint.

    For every ``latent_dim`` in ``latent_dims`` and every SNR in
    ``params['SNR']``, a fresh :class:`Autoencoder` is trained for up to
    ``params['EP']`` epochs with Adam (learning rate ``params['LR']``) on the
    MSE reconstruction loss. After each epoch the test-set PSNR is computed.
    Whenever it improves, the whole model is written with ``torch.save`` into
    ``save_folder``, and then the previous best checkpoint for that
    configuration is deleted. Training stops early after ``n_epochs_stop``
    consecutive epochs without a lower test loss.

    Args:
        trainloader: training DataLoader; the first element of each batch is the image tensor.
        testloader: validation DataLoader, evaluated after every epoch.
        latent_dims: iterable of latent (channel) dimensions to train.
        n_epochs_stop: early-stopping patience in epochs (default 20).
        save_folder: directory that receives the checkpoints (default 'trained_model').
    """
    for latent_dim in latent_dims:
        for snr in params['SNR']:

            model = Autoencoder(latent_dim=latent_dim, encoder_class=Encoder, decoder_class=Decoder).to(device)
            print("Model size : {}".format(count_parameters(model)))
            criterion = nn.MSELoss()
            optimizer = optim.Adam(model.parameters(), lr=params['LR'])

            min_test_cost = float('inf')
            epochs_no_improve = 0

            print("+++++ SNR = {} Training Start! +++++\t".format(snr))

            max_psnr = float('-inf')
            previous_best_model_path = None

            for epoch in range(params['EP']):
                epoch_start = time.time()

                train_cost = _train_one_epoch(model, trainloader, criterion, optimizer, snr)
                tr_psnr = _psnr(train_cost)

                test_cost = _evaluate(model, testloader, criterion, snr)
                val_psnr = _psnr(test_cost)

                # Early-stopping bookkeeping is based on the raw (unrounded) test loss.
                if test_cost < min_test_cost:
                    min_test_cost = test_cost
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

                # Measured over training + validation, as before.
                training_time = time.time() - epoch_start

                print(
                    "[{}-Epoch({}sec.)]  Train PSNR : {:.4f}\tVal PSNR : {:.4f}".format(epoch + 1, round(training_time, 2), tr_psnr, val_psnr))

                if val_psnr > max_psnr:
                    os.makedirs(save_folder, exist_ok=True)
                    previous_psnr, max_psnr = max_psnr, val_psnr

                    save_path = os.path.join(save_folder,
                                             f"DeepJSCC(DIM={latent_dim}_SNR={snr}_PSNR={max_psnr}).pt")
                    # Write the new checkpoint before removing the old one so a
                    # failed save never leaves the run without a best model.
                    torch.save(model, save_path)

                    # Delete the previous best model, if there is one.
                    if previous_best_model_path is not None:
                        os.remove(previous_best_model_path)
                        print(f"Performance update!! {previous_psnr} to {max_psnr}")

                    print(f"Saved new best model at {save_path}")
                    previous_best_model_path = save_path

                if epochs_no_improve >= n_epochs_stop:
                    print("Early stopping!")
                    break  # early stop


# Run the full training sweep only when executed as a script, not on import.
if __name__ == "__main__":
    model_train(trainloader, testloader, params['DIM'])

'Main' 카테고리의 다른 글

Patch importance  (0) 2024.06.03
Matlab code for PSNR performance comparison  (0) 2024.05.24
Image reconstruction with CBM  (0) 2024.05.03
Patch complexity calculated region extending  (0) 2024.05.03
ChessBoard Masking with Colored Random Noise  (0) 2024.05.02
Comments