Basic AutoEncoder using Cifar10

Happy PinGu 2023. 6. 10. 22:50
import math
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as tr
from torch.utils.data import DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

args = {
    'BATCH_SIZE' : 50,
    'LEARNING_RATE' : 0.001,
    'NUM_EPOCH' : 80,
    'SNR_dB' : [1, 10, 20],   # channel SNRs; one model is trained per value
    'latent_dim' : 200,       # dimension of the transmitted latent vector
    'input_dim' : 32 * 32     # pixels per color channel, used in the PSNR denominator
}

transf = tr.Compose([tr.ToTensor(), tr.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # pixels scaled to [-1, 1]

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = False)
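
A quick sanity check (my addition, not part of the original post): with the Normalize above, each batch comes out as a [50, 3, 32, 32] tensor with values in [-1, 1].

imgs, labels = next(iter(trainloader))
print(imgs.shape)                             # torch.Size([50, 3, 32, 32])
print(imgs.min().item(), imgs.max().item())   # roughly -1.0 and 1.0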

def transmit_img_AE(model_name, SNRdB, testloader) :

    model = Autoencoder()
    model.load_state_dict(torch.load(model_name, map_location = 'cpu'))
    model.eval()

    batch_size = args['BATCH_SIZE']

    # Push one test batch through the trained autoencoder.
    with torch.no_grad() :
        for data in testloader :
            inputs = data[0]
            outputs = model(inputs, SNRdB = SNRdB, Rayleigh = 0)
            break

    # Undo the [-1, 1] normalization so imshow renders the colors correctly.
    rand_num = random.randrange(batch_size)
    original = inputs[rand_num] * 0.5 + 0.5
    decoded = outputs[rand_num].clamp(-1, 1) * 0.5 + 0.5

    plt.figure()

    plt.subplot(1, 2, 1)
    plt.imshow(original.permute(1, 2, 0).numpy())
    plt.title('Original img')

    plt.subplot(1, 2, 2)
    plt.imshow(decoded.permute(1, 2, 0).numpy())
    plt.title('Decoded img')

    plt.show()




def PSNR_test(model_name, testloader, Rayleigh) :

    # Recover the training SNR from the checkpoint name, e.g. "..._SNR=10.pth".
    SNRdB = int(model_name.split('SNR=')[-1].split('.pth')[0])

    model = Autoencoder().to(device)
    model.load_state_dict(torch.load(model_name, map_location = device))
    model.eval()

    PSNR_list = []

    with torch.no_grad() :
        for data in testloader :
            inputs = data[0].to(device)
            outputs = model(inputs, SNRdB = SNRdB, Rayleigh = Rayleigh)

            # Per-image PSNR: channel-wise MSE over the 32x32 pixels,
            # PSNR per channel (peak value taken as 1), averaged over RGB.
            for j in range(len(inputs)) :
                PSNR = 0
                for k in range(3) :
                    MSE = torch.sum(torch.pow(inputs[j][k] - outputs[j][k], 2)) / args['input_dim']
                    PSNR += 10 * math.log10(1 / MSE)
                PSNR_list.append(PSNR / 3)

    avg_PSNR = round(sum(PSNR_list) / len(PSNR_list), 2)
    print("PSNR : {}".format(avg_PSNR))

    return avg_PSNR
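
The double loop above can also be written as one vectorized expression. The sketch below (my rewording, same math) makes the assumptions explicit: per-channel MSE over the 1024 pixels, and a peak value of 1. Note that since the images were normalized to [-1, 1], a peak of 1 is this post's convention rather than the full dynamic range of 2.

def batch_PSNR(inputs, outputs) :
    mse = torch.mean((inputs - outputs) ** 2, dim = (2, 3))   # per-image, per-channel MSE
    psnr = 10 * torch.log10(1.0 / mse)                        # peak signal assumed to be 1
    return psnr.mean(dim = 1)                                 # average over the 3 color channels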





def compare_PSNR(SNRdB_list, model_name_without_SNRdB, testloader, Rayleigh) :

    PSNR_list = []

    for SNRdB in SNRdB_list :
        model_name = model_name_without_SNRdB + "_SNR=" + str(SNRdB) + ".pth"
        PSNR = PSNR_test(model_name, testloader, Rayleigh)
        PSNR_list.append(PSNR)

    plt.plot(SNRdB_list, PSNR_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
    plt.xlabel('SNR (dB)')
    plt.ylabel('PSNR (dB)')
    plt.grid(True)
    plt.legend()
    plt.show()





def model_train(SNRdB_list, learning_rate, epoch_num, trainloader, Rayleigh) :

    # Train one model per channel SNR; each gets its own checkpoint.
    for SNRdB in SNRdB_list :

        model = Autoencoder().to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = learning_rate)

        print("+++++ SNR = {} Training Start! +++++".format(SNRdB))

        for epoch in range(epoch_num) :

            running_loss = 0.0

            for data in trainloader :

                inputs = data[0].to(device)

                optimizer.zero_grad()
                outputs = model(inputs, SNRdB = SNRdB, Rayleigh = Rayleigh)
                loss = criterion(outputs, inputs)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            cost = running_loss / len(trainloader)
            print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))

        print()
        PATH = "./"
        torch.save(model.state_dict(), PATH + "model_AWGN(color)" + "_SNR=" + str(SNRdB) + ".pth")
        
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        return self.encoder(x)




class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32

        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
            )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):

        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)

        return decoded
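
A quick shape check (my addition) that the encoder/decoder pair round-trips a CIFAR-10 batch through the 200-dimensional latent:

enc, dec = Encoder(), Decoder()
z = enc(torch.randn(2, 3, 32, 32))
print(z.shape)          # torch.Size([2, 200]) -- args['latent_dim']
print(dec(z).shape)     # torch.Size([2, 3, 32, 32])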



class Autoencoder(nn.Module):
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()


    def add_AWGN2tensor(self, input, SNRdB):

        # Normalize each latent vector to unit power, then add Gaussian noise
        # whose per-dimension variance 1 / (K * SNR) yields the requested SNR.
        normalized_tensor = F.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        noise = torch.normal(0, std, size=normalized_tensor.size()).to(device)

        return normalized_tensor + noise

    def forward(self, x, SNRdB, Rayleigh):

        # Rayleigh is accepted for interface compatibility but unused here;
        # this model simulates an AWGN channel only.
        encoded = self.encoder(x)
        encoded_AWGN = self.add_AWGN2tensor(encoded, SNRdB)
        decoded = self.decoder(encoded_AWGN)

        return decoded
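
To check that add_AWGN2tensor really produces the requested channel SNR, here is a small empirical test (my sketch, not from the original post). After normalization each latent vector has unit power, so the average noise power should come out near 10^(-SNRdB/10):

ae = Autoencoder().to(device)
z = torch.randn(1000, args['latent_dim']).to(device)
noisy = ae.add_AWGN2tensor(z, SNRdB = 10)
noise_power = torch.mean(torch.sum((noisy - F.normalize(z, dim = 1)) ** 2, dim = 1))
print(noise_power.item())   # should be close to 0.1 for SNRdB = 10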
        
SNRdB_list = args['SNR_dB']
learning_rate = args['LEARNING_RATE']
epoch_num = args['NUM_EPOCH']
Rayleigh = 0
model_name_without_SNRdB = 'model_AWGN(color)'

model_train(SNRdB_list, learning_rate, epoch_num, trainloader, Rayleigh)
+++++ SNR = 1 Training Start! +++++
[1 Epoch] Loss : 0.095944
[2 Epoch] Loss : 0.068737
[3 Epoch] Loss : 0.062231
[4 Epoch] Loss : 0.060306
[5 Epoch] Loss : 0.058868
...
[78 Epoch] Loss : 0.046758
[79 Epoch] Loss : 0.046703
[80 Epoch] Loss : 0.046686

+++++ SNR = 10 Training Start! +++++
[1 Epoch] Loss : 0.083862
[2 Epoch] Loss : 0.055587
[3 Epoch] Loss : 0.043076
[4 Epoch] Loss : 0.039104
[5 Epoch] Loss : 0.036578
...
[78 Epoch] Loss : 0.019441
[79 Epoch] Loss : 0.019493
[80 Epoch] Loss : 0.019417

+++++ SNR = 20 Training Start! +++++
[1 Epoch] Loss : 0.080483
[2 Epoch] Loss : 0.049674
[3 Epoch] Loss : 0.038266
[4 Epoch] Loss : 0.034439
[5 Epoch] Loss : 0.031607
...
[78 Epoch] Loss : 0.01271
[79 Epoch] Loss : 0.012709
[80 Epoch] Loss : 0.01272
compare_PSNR(SNRdB_list, model_name_without_SNRdB, testloader, Rayleigh)
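
Finally, a visual spot check on one of the trained checkpoints (my usage sketch; the file name follows the save pattern in model_train):

transmit_img_AE("model_AWGN(color)_SNR=10.pth", SNRdB = 10, testloader = testloader)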
