UOMOP

AWGN_cifar10_20230804 본문

Wireless Comm./CISL

AWGN_cifar10_20230804

Happy PinGu 2023. 8. 4. 14:53
######################## Library ########################

import cv2
import math
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torch.utils.data import DataLoader

######################## GPU Check ########################

# Prefer the first CUDA GPU when available; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

######################## Control Tower ########################
# Hyper-parameters and experiment configuration, read globally below.
args = {
    'BATCH_SIZE': 64,               # mini-batch size for both data loaders
    'LEARNING_RATE': 0.001,         # initial Adam learning rate
    'NUM_EPOCH': 500,               # training epochs per SNR setting
    'SNRdB_list': [1, 10, 20, 30],  # channel SNRs (dB) to train/evaluate at
    'latent_dim': 512,              # nominal latent size (used in noise std & filenames)
    'input_dim': 32 * 32,           # NOTE(review): unused in this script's visible code
    'filter_type': 'scharr',        # NOTE(review): unused in this script's visible code
    'filter_dir': 'x',              # NOTE(review): unused in this script's visible code
    'rl_pat': 10,                   # ReduceLROnPlateau patience (epochs)
}

######################## Function ########################
def sobel_filter(img, x_or_y):
    """Apply a first-order Sobel derivative to img.

    x_or_y: 'x' selects d/dx; any other value selects d/dy.
    ddepth=-1 keeps the output depth equal to the source depth.
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Sobel(img, -1, dx, dy)
def scharr_filter(img, x_or_y):
    """Apply a Scharr derivative (more rotationally accurate than Sobel 3x3).

    x_or_y: 'x' selects d/dx; any other value selects d/dy.
    ddepth=-1 keeps the output depth equal to the source depth.
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Scharr(img, -1, dx, dy)

def average_filter(img, kernel_size):
    """Box-blur img with a square kernel_size x kernel_size averaging kernel."""
    ksize = (kernel_size, kernel_size)
    return cv2.blur(img, ksize)
def gaussian_filter(img, kernel_size):
    """Gaussian-blur img with a square kernel.

    sigma (the X/Y standard deviation of the Gaussian kernel) is passed as 0,
    which lets OpenCV derive it from the kernel size.
    """
    ksize = (kernel_size, kernel_size)
    return cv2.GaussianBlur(img, ksize, 0)
def median_filter(img, kernel_size):
    """Median-blur img; kernel_size is the aperture (odd per OpenCV docs)."""
    aperture = kernel_size
    return cv2.medianBlur(img, aperture)
def bilateral_filter(img, kernel_size):
    """Edge-preserving bilateral smoothing with fixed sigmaColor=sigmaSpace=75."""
    sigma = 75
    return cv2.bilateralFilter(img, kernel_size, sigma, sigma)

def model_train(SNRdB) :
    """Train an Autoencoder end-to-end through an AWGN channel at one SNR.

    SNRdB: channel signal-to-noise ratio (dB) forwarded to the model.

    Reads the globals `trainloader`, `testloader`, `args`, `device`.
    Side effect: saves the trained weights to
    "./AWGN_dim=<latent_dim>_SNRdB=<SNRdB>.pth".
    Prints per-epoch PSNR = 10*log10(1/MSE), valid since pixels lie in [0, 1].
    """
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=args['rl_pat'], threshold=1e-3, verbose = True)

    print("\n\n+++++ SNR = {} Training Start! +++++\t".format(SNRdB))

    for epoch in range(args['NUM_EPOCH']) :

        ############ Train ############
        model.train()
        train_MSE = 0.0

        for train_data in trainloader :
            inputs = train_data[0].to(device)

            optimizer.zero_grad()
            outputs = model(inputs, SNRdB = SNRdB)

            loss = criterion(inputs, outputs)
            loss.backward()
            optimizer.step()
            train_MSE += loss.item()

        train_MSE_average = train_MSE / len(trainloader)

        ############ Validation ############
        # Fix: the original ran validation in training mode and built an
        # autograd graph for every batch; eval() + no_grad() avoid both.
        model.eval()
        val_MSE = 0.0

        with torch.no_grad() :
            for val_data in testloader :
                inputs = val_data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                val_MSE += criterion(inputs, outputs).item()

        val_MSE_average = val_MSE / len(testloader)

        # LR is halved when validation MSE plateaus for `rl_pat` epochs.
        scheduler.step(val_MSE_average)

        print("[{} Epoch] train_PSNR : {}\tvalidation_PSNR : {}"
              .format(epoch + 1, round(10*math.log10(1/train_MSE_average), 3), round(10*math.log10(1/val_MSE_average), 3)))

    torch.save(model.state_dict(), "./AWGN_" + "dim=" + str(args['latent_dim']) + "_SNRdB=" + str(SNRdB) +".pth")


######################## DataSet ########################
transf = tr.Compose([tr.ToTensor()])
# ToTensor converts HWC uint8 images to CHW float tensors scaled to [0, 1].
# NOTE(review): the original comment claimed normalization to [-1, 1], but no
# Normalize transform is applied — pixel values stay in [0, 1] (consistent
# with the PSNR formula 10*log10(1/MSE) used below, which assumes MAX_I = 1).

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

# Both loaders shuffle; shuffling the test split is harmless for averaged MSE.
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)


######################## Encoder ########################

class Encoder(nn.Module):
    """Convolutional encoder: 3x32x32 image -> 19x5x5 latent feature map.

    Two stride-2 'valid' convolutions downsample 32 -> 14 -> 5; three
    'same'-padded stride-1 convolutions then refine features at 5x5.
    The layer order is unchanged, so saved state_dicts remain compatible.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Fix: removed the unused local `c_hid = 32` from the original.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size = 5, padding = 'valid', stride = 2),
            nn.PReLU(),

            nn.Conv2d(16, 32, kernel_size = 5, padding = 'valid', stride = 2),
            nn.PReLU(),

            nn.Conv2d(32, 32, kernel_size = 5, padding = 'same', stride = 1),
            nn.PReLU(),

            nn.Conv2d(32, 19, kernel_size = 5, padding = 'same', stride = 1),
            nn.PReLU(),

            nn.Conv2d(19, 19, kernel_size = 5, padding = 'same', stride = 1),
            nn.PReLU()
        )

    def forward(self, x):
        """Return the encoded feature map for the image batch x."""
        encoded = self.encoder(x)

        return encoded

######################## Decoder ########################

class Decoder(nn.Module):
    """Transposed-convolution decoder: latent feature map -> 3x32x32 image.

    Layer order matches the saved checkpoints (state_dict keys unchanged).
    """

    def __init__(self):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(19, 32, kernel_size=5, padding=0, stride=1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=5, padding=0, stride=1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=5, padding=0, stride=1),
            nn.PReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, padding=0, stride=2),
            nn.PReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=5, padding=0, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode latent x and resize the result to the 32x32 input resolution."""
        upsampled = self.decoder(x)
        # The transposed convs overshoot 32x32; bilinear-resize back down.
        return f.interpolate(upsampled, size=(32, 32), mode='bilinear', align_corners=False)

######################## AutoEncoder ########################
class Autoencoder(nn.Module):
    """Encoder -> simulated AWGN channel -> Decoder, trained end-to-end."""

    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        """Add channel noise to the flattened latent at the given SNR (dB).

        The latent is L2-normalized per sample (unit transmit power), then
        zero-mean Gaussian noise with per-symbol variance 1/(K*SNR) is added,
        where K is the number of transmitted symbols.
        """
        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        # Fix: K must be the actual number of transmitted symbols. With 32x32
        # inputs the encoder flattens to 19*5*5 = 475 values, not
        # args['latent_dim'] (512), so deriving K from the tensor keeps the
        # noise power consistent with the unit-norm signal.
        K = normalized_tensor.size(1)
        std = 1 / math.sqrt(K * SNR)
        # randn_like inherits device/dtype from the signal, removing the
        # implicit dependency on the global `device`.
        n = torch.randn_like(normalized_tensor) * std

        return normalized_tensor + n

    def forward(self, x, SNRdB):
        """Encode x, pass the flattened latent through AWGN, and decode."""
        encoded = self.encoder(x)
        before_shape = encoded.size()

        # Flatten to (batch, K) so normalization/noise act per sample.
        encoded_flatten = torch.flatten(encoded, 1)
        encoded_AWGN = self.AWGN(encoded_flatten, SNRdB)
        reshaped_encoded = encoded_AWGN.view(before_shape)

        decoded = self.decoder(reshaped_encoded)

        return decoded


######################## Train ########################

# Train one model per configured channel SNR.
for snr_db in args['SNRdB_list'] :

    model_train(snr_db)


def psnr_test() :
    """Evaluate each saved model on the test set; return PSNR (dB) per SNR.

    Loads "AWGN_dim=<latent_dim>_SNRdB=<snr>.pth" for every SNR in
    args['SNRdB_list'] and computes 10*log10(1/MSE) over `testloader`
    (valid because ToTensor keeps pixels in [0, 1], so MAX_I = 1).
    Reads the globals `testloader`, `args`, `device`.
    """
    psnr_list = []

    criterion = nn.MSELoss()

    for SNRdB in args['SNRdB_list'] :

        model_name = "AWGN_" + "dim=" + str(args['latent_dim']) + "_SNRdB=" + str(SNRdB) +".pth"
        model = Autoencoder().to(device)
        # map_location lets CPU-only machines load checkpoints saved on GPU.
        model.load_state_dict(torch.load(model_name, map_location=device))
        model.eval()  # fix: evaluation should not run in training mode

        test_MSE = 0.0

        with torch.no_grad() :  # fix: no autograd graph needed at test time
            for test_data in testloader :

                inputs = test_data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)
                test_MSE += criterion(inputs, outputs).item()

        test_MSE_average = test_MSE / len(testloader)

        psnr_list.append(10*math.log10(1/test_MSE_average))

    return psnr_list


psnr_list = psnr_test()

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort when torch and matplotlib each bundle libiomp (common on conda) —
# confirm; it must be set before the plotting backend loads its runtime.
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# PSNR-vs-SNR curve; plt.show() blocks until the window is closed.
plt.plot(args['SNRdB_list'], psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.grid(True)
plt.legend()
plt.ylim([10, 40])
plt.show()

print(psnr_list)

'Wireless Comm. > CISL' 카테고리의 다른 글

이거 성능 왜 높게 나옴?  (0) 2023.08.07
Custom Dataset(Gaussian Filtered Cifar 10)  (0) 2023.08.04
DeepJSCC  (0) 2023.08.02
AWGN(pi=1024 si=0)  (0) 2023.07.28
AWGN(pi=64, si=0)  (0) 2023.07.27
Comments