Proposed net

Happy PinGu 2024. 8. 6. 10:52
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as tr
from torch.utils.data import DataLoader

from params import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


transf = tr.Compose([tr.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)

trainloader = DataLoader(trainset, batch_size=params['BS'], shuffle=True)
testloader = DataLoader(testset, batch_size=params['BS'], shuffle=True)
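
With ToTensor, both loaders yield float batches of shape [BS, 3, 32, 32] scaled to [0, 1]. A quick sanity check (a minimal sketch; BS = 64 comes from the params file shown at the end of the post):

images, _ = next(iter(trainloader))
print(images.shape)                              # torch.Size([64, 3, 32, 32])
print(images.min().item(), images.max().item())  # values within [0, 1]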


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 30, 30]

            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 28, 28]

            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 26, 26]

            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 24, 24]

            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 22, 22]

            nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 20, 20]

            nn.Flatten(),
            nn.Linear(6400, self.latent_dim),  # 16 * 20 * 20 = 6400
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded
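
Each 5x5 convolution with stride 1 and padding 1 trims the feature map by 2 pixels per dimension, taking 32x32 down to 20x20 over six layers, which is where the 16 * 20 * 20 = 6400 input to the linear layer comes from. A minimal shape check (latent_dim = 512 is only an example value):

enc = Encoder(latent_dim=512)
print(enc(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 512])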


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 6400),
            nn.PReLU(),
            nn.Unflatten(1, (16, 20, 20)),
            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=2),
            nn.PReLU(),
            # Output: [batch, 16, 20, 20]

            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 22, 22]

            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 24, 24]

            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 26, 26]

            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 28, 28]

            nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1),
            nn.PReLU(),
            # Output: [batch, 16, 30, 30]

            nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1),
            nn.Sigmoid(),
            # Output: [batch, 3, 32, 32]
        )

    def forward(self, x):
        decoded = self.decoder(x)
        return decoded
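
The decoder mirrors that path: the first transposed convolution (padding=2) keeps 20x20, each later one (padding=1) grows the map by 2 per dimension until the final layer reaches 32x32, and the Sigmoid keeps the RGB output in [0, 1]. The matching check:

dec = Decoder(latent_dim=512)
print(dec(torch.randn(2, 512)).shape)  # torch.Size([2, 3, 32, 32])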


class Autoencoder(nn.Module):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def Power_norm(self, z, P=1 / np.sqrt(2)):
        # Scale each latent vector so its total power is P * z_dim.
        batch_size, z_dim = z.shape
        z_power = torch.sqrt(torch.sum(z ** 2, 1))
        z_M = z_power.repeat(z_dim, 1)
        return np.sqrt(P * z_dim) * z / z_M.t()

    def Power_norm_complex(self, z, P=1 / np.sqrt(2)):
        # Pair up real entries into complex symbols, normalize the total power
        # to P * z_dim, then interleave real/imaginary parts back into a real vector.
        batch_size, z_dim = z.shape
        z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
        z_com_conj = torch.complex(z[:, 0:z_dim:2], -z[:, 1:z_dim:2])
        z_power = torch.sum(z_com * z_com_conj, 1).real
        z_M = z_power.repeat(z_dim // 2, 1)
        z_nlz = np.sqrt(P * z_dim) * z_com / torch.sqrt(z_M.t())
        z_out = torch.zeros(batch_size, z_dim).to(device)
        z_out[:, 0:z_dim:2] = z_nlz.real
        z_out[:, 1:z_dim:2] = z_nlz.imag
        return z_out

    def AWGN_channel(self, x, snr, P=1):
        # Add white Gaussian noise sized for the target SNR (snr is in dB).
        batch_size, length = x.shape
        gamma = 10 ** (snr / 10.0)
        noise = np.sqrt(P / gamma) * torch.randn(batch_size, length).to(device)
        return x + noise

    def Fading_channel(self, x, snr, P=1):
        # Rayleigh fading: multiply each complex symbol by a random complex gain h,
        # add Gaussian noise, then equalize the received signal by dividing by h.
        gamma = 10 ** (snr / 10.0)
        batch_size, feature_length = x.shape
        K = feature_length // 2

        h_I = torch.randn(batch_size, K).to(device)
        h_R = torch.randn(batch_size, K).to(device)
        h_com = torch.complex(h_I, h_R)
        x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
        y_com = h_com * x_com

        n_I = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
        n_R = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
        noise = torch.complex(n_I, n_R)

        y_add = y_com + noise
        y = y_add / h_com

        y_out = torch.zeros(batch_size, feature_length).to(device)
        y_out[:, 0:feature_length:2] = y.real
        y_out[:, 1:feature_length:2] = y.imag

        return y_out

    def forward(self, x, SNRdB, channel):
        encoded = self.encoder(x)

        if channel == 'AWGN':
            normalized_x = self.Power_norm(encoded)
            channel_output = self.AWGN_channel(normalized_x, SNRdB)
        elif channel == 'Rayleigh':
            normalized_complex_x = self.Power_norm_complex(encoded)
            channel_output = self.Fading_channel(normalized_complex_x, SNRdB)
        else:
            raise ValueError(f"Unsupported channel: {channel}")

        decoded = self.decoder(channel_output)
        return decoded
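
A minimal end-to-end smoke test (a sketch; latent_dim = 512 and SNRdB = 20 are arbitrary example values). It also checks the invariant Power_norm enforces before the channel, namely a total power of P * z_dim = 512 / sqrt(2) per vector:

model_chk = Autoencoder(latent_dim=512).to(device)
x = torch.rand(2, 3, 32, 32).to(device)

z_n = model_chk.Power_norm(model_chk.encoder(x))
print(torch.sum(z_n ** 2, dim=1))    # each entry ≈ 362.04

x_hat = model_chk(x, SNRdB=20, channel='AWGN')
print(x_hat.shape)                   # torch.Size([2, 3, 32, 32])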


def model_train(trainloader, testloader, latent_dims):
    for latent_dim in latent_dims:
        for snr_i in range(len(params['SNR'])):

            model = Autoencoder(latent_dim=latent_dim).to(device)
            print("Model size : {}".format(count_parameters(model)))
            criterion = nn.MSELoss()
            optimizer = optim.Adam(model.parameters(), lr=params['LR'])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

            min_test_cost = float('inf')
            epochs_no_improve = 0
            n_epochs_stop = 43

            print("+++++ SNR = {} Training Start! +++++\t".format(params['SNR'][snr_i]))

            max_psnr = 0
            previous_best_model_path = None

            for epoch in range(params['EP']):

                # ========================================== Train ==========================================
                train_loss = 0.0

                model.train()
                timetemp = time.time()

                for data in trainloader:
                    inputs = data[0].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs, SNRdB=params['SNR'][snr_i], channel=params['channel'])
                    loss = criterion(inputs, outputs)
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item()

                train_cost = train_loss / len(trainloader)
                tr_psnr = round(10 * math.log10(1.0 / train_cost), 3)
                # ========================================== Test ==========================================
                test_loss = 0.0

                model.eval()
                with torch.no_grad():
                    for data in testloader:
                        inputs = data[0].to(device)

                        outputs = model(inputs, SNRdB=params['SNR'][snr_i], channel=params['channel'])
                        loss = criterion(inputs, outputs)
                        test_loss += loss.item()

                    test_cost = test_loss / len(testloader)
                    test_psnr = round(10 * math.log10(1.0 / test_cost), 3)

                scheduler.step(test_cost)

                if test_cost < min_test_cost:
                    min_test_cost = test_cost
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                if epochs_no_improve == n_epochs_stop:
                    print("Early stopping!")
                    break  # early stop

                training_time = time.time() - timetemp

                print(
                    "[{:>3}-Epoch({:>5}sec.)]  PSNR(Train / Val) : {:>6.4f} / {:>6.4f}  ".format(epoch + 1, round(training_time, 2), tr_psnr, test_psnr))

                if test_psnr > max_psnr:

                    save_folder = 'trained_model'

                    if not os.path.exists(save_folder):
                        os.makedirs(save_folder)
                    previous_psnr = max_psnr
                    max_psnr = test_psnr

                    # Delete the previous best checkpoint, if one exists
                    if previous_best_model_path is not None:
                        os.remove(previous_best_model_path)
                        print(f"Performance update!! {previous_psnr} to {max_psnr}")

                    save_path = os.path.join(save_folder,
                                             f"DeepJSCC(DIM={latent_dim}_SNR={params['SNR'][snr_i]}_PSNR={max_psnr}).pt")
                    torch.save(model, save_path)
                    print(f"Saved new best model at {save_path}")

                    previous_best_model_path = save_path


model_train(trainloader, testloader, params['DIM'])

# params.py (imported at the top via `from params import *`)
params = {
    'BS': 64,
    'LR': 0.0005,
    'EP': 1000,
    'SNR': [40, 20],
    'DIM': [1024, 512, 256],
    'channel': 'Rayleigh'
}
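
One way to read the DIM values (my interpretation, not stated in the post): the source image has 3 * 32 * 32 = 3072 dimensions, so a latent of DIM real values corresponds to a bandwidth compression ratio of DIM / 3072:

for dim in params['DIM']:
    print(f"DIM={dim}: ratio = {dim / (3 * 32 * 32):.4f}")  # 0.3333, 0.1667, 0.0833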
