UOMOP

DeepJSCC 본문

Wireless Comm./CISL

DeepJSCC

Happy PinGu 2023. 8. 2. 22:06
import cv2
import math
import time
import torch
import random
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset

# Run on GPU if CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Training hyperparameters.
args = {
    'BATCH_SIZE' : 64,
    'LEARNING_RATE' : 0.001,
    'EPOCH' : 200,
    'SNRdB_list' : [0, 10, 20, 30],  # one model is trained/evaluated per SNR value (dB)
    'input_dim' : 32 * 32,  # CIFAR-10 spatial size (not referenced below)
    'rl_pat' : 7  # patience (epochs) for ReduceLROnPlateau
}

# Only convert images to tensors (scales pixels to [0, 1]); no augmentation.
transf = tr.Compose([tr.ToTensor()])

# CIFAR-10 train/test splits, downloaded into ./data on first run.
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)

class Encoder(nn.Module):
    """Convolutional encoder: maps a (B, 3, 32, 32) image batch to a
    (B, 19, 5, 5) feature map of channel symbols.

    Two stride-2 'valid' convolutions downsample 32x32 -> 14x14 -> 5x5,
    then three stride-1 'same' convolutions refine the features.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # (in_channels, out_channels, padding, stride); kernel size is 5 throughout.
        layer_specs = [
            (3, 16, 'valid', 2),
            (16, 32, 'valid', 2),
            (32, 32, 'same', 1),
            (32, 19, 'same', 1),
            (19, 19, 'same', 1),
        ]
        modules = []
        for c_in, c_out, pad, stride in layer_specs:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=5, padding=pad, stride=stride))
            modules.append(nn.PReLU())
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode a batch of images into channel symbols."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Transposed-convolutional decoder: maps (B, 19, 5, 5) channel symbols
    back to a (B, 3, 32, 32) image with pixel values in [0, 1].

    NOTE: since torch 1.10 nn.Conv2d accepts 'same'/'valid' strings for
    padding, but nn.ConvTranspose2d does not.  Explicit padding=0 is used
    instead and the output is resized to 32x32 by bilinear interpolation.
      padding_height = [strides[0] * (in_height - 1) + kernel_size[0] - out_height] / 2
      padding_width  = [strides[1] * (in_width  - 1) + kernel_size[1] - out_width ] / 2
    """

    def __init__(self):
        super(Decoder, self).__init__()

        # (in_channels, out_channels, stride); kernel size 5, padding 0 throughout.
        deconv_specs = [(19, 32, 1), (32, 32, 1), (32, 32, 1), (32, 16, 2)]
        modules = []
        for c_in, c_out, stride in deconv_specs:
            modules.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=5, padding=0, stride=stride))
            modules.append(nn.PReLU())
        # Final layer maps to 3 image channels and squashes values into [0, 1].
        modules.append(nn.ConvTranspose2d(16, 3, kernel_size=5, padding=0, stride=2))
        modules.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Decode symbols and resize the reconstruction to 32x32."""
        upsampled = self.decoder(x)
        return f.interpolate(upsampled, size=(32, 32), mode='bilinear', align_corners=False)

class Autoencoder(nn.Module):
    """End-to-end DeepJSCC model: encoder -> AWGN channel -> decoder."""

    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        """Simulate an AWGN channel on flattened symbols.

        Parameters
        ----------
        input : (B, K) tensor of flattened channel symbols.
        SNRdB : signal-to-noise ratio in dB.

        Returns the unit-normalized symbols plus Gaussian noise whose
        variance is 1 / (K * SNR).
        """
        # Normalize each sample to unit norm so transmit power is fixed.
        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        # BUG FIX: K was hard-coded to 475 (= 19 channels * 5 * 5 of the
        # current encoder output).  Deriving it from the input keeps the
        # channel correct if the encoder architecture changes.
        K = input.size(1)
        std = 1 / math.sqrt(K * SNR)
        # Draw noise on the signal's own device rather than the module-level
        # global `device`; identical behavior for the existing callers.
        n = torch.normal(0, std, size=normalized_tensor.size()).to(input.device)

        return normalized_tensor + n

    def forward(self, x, SNRdB):
        """Encode x, transmit it through the AWGN channel at SNRdB, decode."""
        encoded = self.encoder(x)
        shape = encoded.size()

        # The channel operates on flat (B, K) vectors.
        encoded_flatten = torch.flatten(encoded, start_dim=1)
        encoded_AWGN = self.AWGN(encoded_flatten, SNRdB)
        encoded_AWGN_reshaped = encoded_AWGN.view(shape)

        decoded = self.decoder(encoded_AWGN_reshaped)
        return decoded

def model_train(trainloader, testloader):
    """Train one Autoencoder per SNR in args['SNRdB_list'] and save each model.

    For every SNR a fresh model is trained for args['EPOCH'] epochs with Adam
    and ReduceLROnPlateau (stepped on the validation loss).  Final weights are
    saved to "DeepJSCC_<SNR>dB.pth".  Reported metric is PSNR = 10*log10(1/MSE)
    (valid because pixels are in [0, 1]).
    """
    for SNRdB in args['SNRdB_list']:

        model = Autoencoder().to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=args['LEARNING_RATE'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=args['rl_pat'],
            threshold=1e-3, verbose=True)

        print("+++++ SNR = {} Training Start! +++++\t".format(SNRdB))

        for epoch in range(args['EPOCH']):

            # ================= Train =================
            model.train()
            train_loss = 0.0
            for data in trainloader:
                inputs = data[0].to(device)
                optimizer.zero_grad()
                outputs = model(inputs, SNRdB=SNRdB)
                loss = criterion(outputs, inputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_cost = train_loss / len(trainloader)

            # ================= Validate =================
            # BUG FIX: the original evaluation loop built autograd graphs
            # (no torch.no_grad()) and called a stray optimizer.zero_grad().
            model.eval()
            test_loss = 0.0
            with torch.no_grad():
                for data in testloader:
                    inputs = data[0].to(device)
                    outputs = model(inputs, SNRdB=SNRdB)
                    test_loss += criterion(outputs, inputs).item()
            test_cost = test_loss / len(testloader)

            # Reduce the LR when validation loss plateaus.
            scheduler.step(test_cost)

            print("[{} Epoch] Train PSNR : {}      Val. PSNR : {}".format(
                epoch + 1,
                round(10 * math.log10(1 / train_cost), 4),
                round(10 * math.log10(1 / test_cost), 4)))

        torch.save(model.state_dict(), "DeepJSCC_" + str(SNRdB) + "dB.pth")

#model_train(trainloader, testloader)

def DeepJSCC_test(testloader):
    """Evaluate each saved per-SNR model on testloader and plot PSNR vs SNR.

    Loads "DeepJSCC_<SNR>dB.pth" for every SNR in args['SNRdB_list'], computes
    the average test MSE without gradient tracking, converts it to PSNR
    (pixels in [0, 1], so PSNR = 10*log10(1/MSE)), and plots the curve.
    """
    criterion = nn.MSELoss()  # stateless; one instance suffices for all SNRs
    psnr_list = []

    for SNRdB in args['SNRdB_list']:

        model = Autoencoder().to(device)
        # BUG FIX: map_location lets checkpoints trained on GPU load on
        # CPU-only machines as well.
        model_location = "DeepJSCC_" + str(SNRdB) + "dB.pth"
        model.load_state_dict(torch.load(model_location, map_location=device))
        model.eval()

        test_loss = 0.0
        # BUG FIX: evaluation must not build autograd graphs (the original
        # tracked gradients and needlessly called optimizer machinery).
        with torch.no_grad():
            for data in testloader:
                inputs = data[0].to(device)
                outputs = model(inputs, SNRdB=SNRdB)
                test_loss += criterion(outputs, inputs).item()

        test_cost = test_loss / len(testloader)
        psnr_list.append(round(10 * math.log10(1 / test_cost), 4))

    plt.plot(args["SNRdB_list"], psnr_list, linestyle='dashed', color='blue', label="AWGN")
    plt.xlabel("SNR (dB)")
    plt.ylabel("PSNR (dB)")
    plt.grid(True)
    plt.legend()
    plt.show()
# Workaround for the "duplicate OpenMP runtime" (libiomp5) crash that can
# occur when torch and matplotlib are loaded together on some platforms.
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
DeepJSCC_test(testloader)

성능 잘 안나옴.
Dimension 다시 확인해봐야함.

'Wireless Comm. > CISL' 카테고리의 다른 글

Custom Dataset(Gaussian Filtered Cifar 10)  (0) 2023.08.04
AWGN_cifar10_20230804  (0) 2023.08.04
AWGN(pi=1024 si=0)  (0) 2023.07.28
AWGN(pi=64, si=0)  (0) 2023.07.27
AWGN(pi=128, si=0)  (0) 2023.07.27
Comments