DeepJSCC(My code)_batch:32, lr:0.001, R:1/6

Happy PinGu 2023. 10. 25. 10:14
import cv2
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

args = {
    'BATCH_SIZE' : 32,
    'LEARNING_RATE' : 0.001,
    'NUM_EPOCH' : 500,
    'SNRdB_list' : [0, 15, 30],
    'latent_dim' : 512,
    'input_dim' : 32*32
}
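
# Sanity check (my addition, not part of the original training flow): the bandwidth
# ratio R in the title is the latent dimension divided by the number of real values
# in one CIFAR-10 image, i.e. 512 / (3 * 32 * 32) = 1/6.
print("R =", Fraction(args['latent_dim'], 3 * args['input_dim']))  # expected: 1/6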

transf = tr.Compose([tr.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)
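
# Note (my addition): tr.ToTensor() already scales pixels to [0, 1], which is why the
# PSNR values printed later use a peak signal of 1, i.e. PSNR = 10 * log10(1 / MSE).
sample_batch, _ = next(iter(trainloader))
print(sample_batch.shape, sample_batch.min().item(), sample_batch.max().item())  # e.g. torch.Size([32, 3, 32, 32]) 0.0 1.0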

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        encoded = self.encoder(x)

        return encoded
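
# Quick shape check (my addition): the encoder compresses a 3x32x32 image into a
# latent vector of length args['latent_dim'] = 512.
_enc = Encoder().to(device)
print(_enc(torch.randn(4, 3, 32, 32).to(device)).shape)  # expected: torch.Size([4, 512])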

class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32

        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),

            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),

            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)

        return decoded

class Autoencoder(nn.Module):
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()



    def AWGN(self, input, SNRdB):

        # Normalize the latent vector to unit L2 norm (unit total transmit power),
        # then add white Gaussian noise whose per-symbol variance 1/(K*SNR)
        # realizes the requested channel SNR in dB.
        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)

        return normalized_tensor + n



    def forward(self, x, SNRdB):

        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)

        return decoded
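
# End-to-end check (my addition, assuming the classes above): an image passes through
# the encoder, the AWGN channel at the given SNR, and the decoder, and comes back at
# its original resolution.
_tmp_model = Autoencoder().to(device)
print(_tmp_model(torch.randn(2, 3, 32, 32).to(device), SNRdB=10).shape)  # expected: torch.Size([2, 3, 32, 32])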

def model_train(trainloader, testloader) :

    for i in range( len(args['SNRdB_list']) ):

        model = Autoencoder().to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])

        print("+++++ SNR = {} Training Start! +++++\t".format(args['SNRdB_list'][i]))

        for epoch in range(args['NUM_EPOCH']) :

#========================================== Train ==========================================
            train_loss = 0.0

            model.train()

            for data in trainloader :

                inputs = data[0].to(device)
                optimizer.zero_grad()
                outputs = model( inputs, SNRdB = args['SNRdB_list'][i])
                loss = criterion(inputs, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_cost = train_loss / len(trainloader)
#========================================== Test ==========================================
            test_loss = 0.0

            model.eval()

            # Evaluation only: no gradients or optimizer updates are needed here.
            with torch.no_grad():
                for data in testloader :

                    inputs = data[0].to(device)
                    outputs = model( inputs, SNRdB = args['SNRdB_list'][i])
                    loss = criterion(inputs, outputs)
                    test_loss += loss.item()

            test_cost = test_loss / len(testloader)

            # print("[{} Epoch] Train Loss : {}      Val. Loss : {}".format(epoch + 1, round(train_cost, 6), round(test_cost, 6)))
            print("[{} Epoch] Train PSNR : {}      Val. PSNR : {}".format(epoch + 1, round(10 * math.log10(1.0 / train_cost), 4), round(10 * math.log10(1.0 / test_cost), 4)))
            #print("[{} Epoch] Train PSNR : {}      Val. PSNR : {}".format(cv2.PSNR(inputs, outputs)))
        torch.save(model.state_dict(), "DeepJSCC(" + str(args['latent_dim'])+  ")_" + str(args['SNRdB_list'][i]) + "dB.pth")

model_train(trainloader, testloader)

def DeepJSCC_test(testloader) :

    psnr_list = []

    for i in range(len(args['SNRdB_list'])):

        criterion = nn.MSELoss()

        SNRdB = args['SNRdB_list'][i]
        model_location =  "DeepJSCC(" + str(args['latent_dim'])+  ")_" + str(SNRdB) + "dB.pth"
        model = Autoencoder().to(device)
        model.load_state_dict(torch.load(model_location))

        test_loss = 0.0

        model.eval()

        # Evaluation only: disable gradient tracking while measuring the test MSE.
        with torch.no_grad():
            for data in testloader :

                inputs = data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                loss = criterion(inputs, outputs)
                test_loss += loss.item()

        test_cost = test_loss / len(testloader)
        psnr_list.append(round(10*math.log10(1/test_cost), 4))

    plt.plot(args["SNRdB_list"], psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
    plt.grid(True)
    plt.legend()
    plt.title('batch size:' + str(args['BATCH_SIZE']) + '   ||   lr:' + str(args['LEARNING_RATE']) + '   ||   R:' + str(Fraction(args['latent_dim'], 3 * args['input_dim'])))
    plt.xlabel('SNR(dB)')
    plt.ylabel('PSNR')
    plt.ylim([15, 40])
    plt.show()

    return psnr_list


import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # workaround for duplicate OpenMP runtime errors when plotting with matplotlib
psnr_list= DeepJSCC_test(testloader)
print("PSNR data : {}".format(psnr_list))
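
# PSNR conversion used throughout (my addition, hypothetical helper): for images in
# [0, 1], PSNR = 10 * log10(1 / MSE).
def mse_to_psnr(mse):
    return 10 * math.log10(1.0 / mse)

print(mse_to_psnr(0.001))  # ~30 dB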
