UOMOP

***AutoEncoder cifar10(color) 1dB, 10dB, 20dB 본문

Wireless Comm./Python

***AutoEncoder cifar10(color) 1dB, 10dB, 20dB

Happy PinGu 2023. 6. 8. 20:25
import math
import torch
import random
import torchvision
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torch.utils.data import DataLoader
import numpy as np

args = {
    'BATCH_SIZE' : 64,
    'LEARNING_RATE' : 0.001,
    'NUM_EPOCH' : 20,
    'SNR_dB' : [1, 10, 20],
    'latent_dim' : 100,
    'input_dim' : 32 * 32
}

transf = tr.Compose([tr.ToTensor(), tr.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = False)


def transmit_img_AE(model_name, SNRdB, testloader) :

    model = Autoencoder()
    model.load_state_dict(torch.load(model_name))

    fig = plt.figure()

    batch_size = args['BATCH_SIZE']

    for data in testloader :

        inputs = data[0]

        outputs = model(inputs, SNRdB = SNRdB, Rayleigh = 0)

        break

    rand_num = random.randrange(batch_size)

    plt.subplot(1, 2, 1)
    plt.imshow( inputs[rand_num].permute(1, 2, 0).detach().numpy())
    plt.title('Original img')

    plt.subplot(1, 2, 2)
    plt.imshow(outputs[rand_num].permute(1, 2, 0).detach().numpy())
    plt.title('Decoded img')



def model_train(SNRdB_list, learning_rate, epoch_num, trainloader, Rayleigh) :

    for i in range( len(args['SNR_dB']) ):

        model = Autoencoder()
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = learning_rate)

        print("+++++ SNR = {} Training Start! +++++".format(args['SNR_dB'][i]))

        for epoch in range(epoch_num) :

            running_loss = 0.0

            for data in trainloader :

                inputs = data[0]
                optimizer.zero_grad()
                outputs = model( inputs, SNRdB = args['SNR_dB'][i], Rayleigh = Rayleigh)
                loss = criterion(inputs, outputs)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            cost = running_loss/len(trainloader)
            print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))
        print()
        PATH = "./"
        torch.save(model.state_dict(), PATH + "model_AWGN(color)" + "_SNR=" + str(args['SNR_dB'][i]) +".pth")




def PSNR_test(model_name, testloader, Rayleigh) :

    if model_name[-6] == '=' :
        SNRdB = int(model_name[-5])
    else :
        SNRdB = int(model_name[-6] + model_name[-5])

    model_1 = Autoencoder()
    model_1.load_state_dict(torch.load(model_name))

    PSNR_list = []

    for data in testloader :
        inputs = data[0]

        outputs = model_1( inputs, SNRdB = SNRdB, Rayleigh = Rayleigh)

        for j in range(len(inputs)) :
            PSNR = 0
            for k in range(3) :

                MSE = torch.sum(torch.pow(inputs[j][k] - outputs[j][k], 2)) / args['input_dim']
                PSNR += 10 * math.log10(1 / MSE)

            PSNR_list.append(PSNR / 3)
    print("PSNR : {}".format(round(sum(PSNR_list) / len(PSNR_list), 2)))

    return round(sum(PSNR_list) / len(PSNR_list), 2)




def compare_PSNR(SNRdB_list, model_name_without_SNRdB, testloader, Rayleigh) :

    PSNR_list = []

    for i in range(len(SNRdB_list)) :
        SNRdB = SNRdB_list[i]
        model_name = model_name_without_SNRdB + "_SNR=" + str(SNRdB) + ".pth"
        PSNR = PSNR_test(model_name, testloader, Rayleigh)
        PSNR_list.append(PSNR)

    plt.plot(args['SNR_dB'], PSNR_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
    plt.grid(True)
    plt.legend()
    plt.show()



class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        return self.encoder(x)



class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32

        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
            )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):

        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)

        return decoded




class Autoencoder(nn.Module):
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()


    def add_AWGN2tensor(self, input, SNRdB):

        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        noise = torch.normal(0, std, size=normalized_tensor.size())

        return input + noise



    def forward(self, x, SNRdB, Rayleigh):

        encoded = self.encoder(x)
        encoded_AWGN = self.add_AWGN2tensor(encoded, SNRdB)
        decoded = self.decoder(encoded_AWGN)

        return decoded


SNRdB_list = args['SNR_dB']
learning_rate = args['LEARNING_RATE']
epoch_num = args['NUM_EPOCH']
trainloader = trainloader
Rayleigh = 0
model_name_without_SNRdB = 'model_AWGN(color)'



model_train(SNRdB_list, learning_rate, epoch_num, trainloader, Rayleigh)

compare_PSNR(SNRdB_list, model_name_without_SNRdB, testloader, Rayleigh)

+++++ SNR = 1 Training Start! +++++
[1 Epoch] Loss : 0.084596
[2 Epoch] Loss : 0.042723
[3 Epoch] Loss : 0.037308
[4 Epoch] Loss : 0.03358
[5 Epoch] Loss : 0.03099
[6 Epoch] Loss : 0.027896
[7 Epoch] Loss : 0.026058
[8 Epoch] Loss : 0.025078
[9 Epoch] Loss : 0.024269
[10 Epoch] Loss : 0.023837
[11 Epoch] Loss : 0.023479
[12 Epoch] Loss : 0.023303
[13 Epoch] Loss : 0.023112
[14 Epoch] Loss : 0.022917
[15 Epoch] Loss : 0.02271
[16 Epoch] Loss : 0.022642
[17 Epoch] Loss : 0.022511
[18 Epoch] Loss : 0.0224
[19 Epoch] Loss : 0.022355
[20 Epoch] Loss : 0.022207

+++++ SNR = 10 Training Start! +++++
[1 Epoch] Loss : 0.084925
[2 Epoch] Loss : 0.044145
[3 Epoch] Loss : 0.036439
[4 Epoch] Loss : 0.032975
[5 Epoch] Loss : 0.030463
[6 Epoch] Loss : 0.027844
[7 Epoch] Loss : 0.02591
[8 Epoch] Loss : 0.024768
[9 Epoch] Loss : 0.023982
[10 Epoch] Loss : 0.023531
[11 Epoch] Loss : 0.023166
[12 Epoch] Loss : 0.023027
[13 Epoch] Loss : 0.022712
[14 Epoch] Loss : 0.022614
[15 Epoch] Loss : 0.022475
[16 Epoch] Loss : 0.022364
[17 Epoch] Loss : 0.022289
[18 Epoch] Loss : 0.022228
[19 Epoch] Loss : 0.022079
[20 Epoch] Loss : 0.022065

+++++ SNR = 20 Training Start! +++++
[1 Epoch] Loss : 0.082607
[2 Epoch] Loss : 0.041877
[3 Epoch] Loss : 0.036044
[4 Epoch] Loss : 0.03161
[5 Epoch] Loss : 0.029085
[6 Epoch] Loss : 0.027129
[7 Epoch] Loss : 0.025665
[8 Epoch] Loss : 0.024617
[9 Epoch] Loss : 0.024005
[10 Epoch] Loss : 0.023523
[11 Epoch] Loss : 0.023244
[12 Epoch] Loss : 0.022985
[13 Epoch] Loss : 0.022784
[14 Epoch] Loss : 0.022664
[15 Epoch] Loss : 0.02248
[16 Epoch] Loss : 0.022396
[17 Epoch] Loss : 0.022241
[18 Epoch] Loss : 0.022156
[19 Epoch] Loss : 0.022091
[20 Epoch] Loss : 0.022019

PSNR : 17.12
PSNR : 17.22
PSNR : 17.23

Comments