UOMOP

Traditional AutoEncoder Cifar10(gray) 2023 06 07 본문

Wireless Comm./Python

Traditional AutoEncoder Cifar10(gray) 2023 06 07

Happy PinGu 2023. 6. 7. 19:05
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torch.utils.data import DataLoader

# Hyper-parameters shared by the models, the training loop and the evaluation below.
args = dict(
    BATCH_SIZE=50,
    LEARNING_RATE=0.001,
    NUM_EPOCH=80,
    SNR_dB=list(range(1, 18)),   # channel SNRs in dB, one model per value: 1..17
    latent_dim=100,              # size of the code sent over the channel
    input_dim=32 * 32,           # one flattened 32x32 grayscale image
)

# PIL image -> float tensor, then collapse RGB into a single gray channel.
transf = tr.Compose([
    tr.ToTensor(),
    tr.Grayscale(num_output_channels=1),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)

# Shuffle only the training split; the test order stays fixed.
trainloader = DataLoader(trainset, batch_size=args['BATCH_SIZE'], shuffle=True)
testloader = DataLoader(testset, batch_size=args['BATCH_SIZE'], shuffle=False)




class Encoder(nn.Module):
    """Map a flattened image of args['input_dim'] values to an args['latent_dim'] code.

    Three Linear layers (input_dim -> 512 -> 256 -> latent_dim), each followed
    by a ReLU.
    """

    def __init__(self):
        super().__init__()

        widths = [args['input_dim'], 512, 256, args['latent_dim']]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Attribute name and layer order are unchanged so saved state_dict
        # keys (encoder.0.weight, encoder.2.weight, ...) still line up.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of flattened images."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Map an args['latent_dim'] code back to a flattened args['input_dim'] image.

    Linear layers latent_dim -> 256 -> 512 -> input_dim. A ReLU follows each
    hidden layer and a Sigmoid follows the last, so every output lies in (0, 1).
    """

    def __init__(self):
        super().__init__()

        widths = (args['latent_dim'], 256, 512, args['input_dim'])
        last = len(widths) - 2
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.Sigmoid() if idx == last else nn.ReLU())
        # Same attribute name and layer order as before, so state_dict keys match.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a batch of latent codes into flattened images."""
        return self.decoder(x)


class Autoencoder(nn.Module):
    """Encoder -> simulated noisy channel -> decoder pipeline.

    The latent code produced by the encoder is power-normalized and corrupted
    by additive white Gaussian noise at a given SNR before being decoded.
    """

    def __init__(
            self,
            encoder_class: type = Encoder,
            decoder_class: type = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def add_AWGN2tensor(self, input, SNRdB):
        """Power-normalize each latent vector and add white Gaussian noise.

        Args:
            input: latent batch; assumed (batch, latent_dim), as produced by
                the encoder on inputs flattened with view(-1, input_dim).
            SNRdB: channel signal-to-noise ratio in dB.

        Returns:
            Tensor of the same shape: the normalized signal plus noise.
        """
        # f.normalize gives each row unit L2 norm (total energy 1). Scaling by
        # sqrt(k) makes the average power per element 1, which is the signal
        # power the noise std below assumes.
        num_symbols = input.size(1)
        normalized_tensor = f.normalize(input, dim=1) * math.sqrt(num_symbols)

        SNR = 10.0 ** (SNRdB / 10.0)
        std = 1 / math.sqrt(SNR)
        # randn_like keeps the noise on the same device/dtype as the signal.
        noise = torch.randn_like(normalized_tensor) * std

        # Previously the noise was added to the un-normalized input, so the
        # normalization was discarded and the real SNR depended on the encoder's
        # output scale.
        return normalized_tensor + noise

    def forward(self, x, SNRdB, Rayleigh):
        """Encode `x`, pass the code through the channel, and decode it.

        Args:
            x: flattened input batch.
            SNRdB: channel SNR in dB.
            Rayleigh: channel selector; 0 = AWGN. 1 (Rayleigh fading) is not
                implemented.

        Raises:
            NotImplementedError: if Rayleigh == 1.
            ValueError: for any other unsupported channel selector.
        """
        encoded = self.encoder(x)

        if Rayleigh == 0:
            encoded_after_channel = self.add_AWGN2tensor(encoded, SNRdB)
        elif Rayleigh == 1:
            # Previously this printed 1 and then crashed with UnboundLocalError.
            raise NotImplementedError("Rayleigh fading channel is not implemented")
        else:
            raise ValueError("Rayleigh must be 0 (AWGN) or 1 (Rayleigh), got {!r}".format(Rayleigh))

        decoded = self.decoder(encoded_after_channel)

        return decoded


# Train one independent model per channel SNR and save each checkpoint.
for snr_db in args['SNR_dB']:

    # A fresh model/optimizer per SNR: each checkpoint is specialised for one
    # channel condition.
    model = Autoencoder()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])

    print("+++++ SNR = {} Training Start! +++++".format(snr_db))

    for epoch in range(args['NUM_EPOCH']):

        running_loss = 0.0

        for inputs, _ in trainloader:  # labels are unused by the autoencoder
            optimizer.zero_grad()
            outputs = model(inputs.view(-1, args['input_dim']), SNRdB = snr_db, Rayleigh = 0)
            outputs = outputs.view(-1, 1, 32, 32)
            # nn.MSELoss takes (prediction, target); previously the order was
            # swapped.
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        cost = running_loss / len(trainloader)
        print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))
    print()
    PATH = "./"
    torch.save(model.state_dict(), PATH + "model_AWGN" + "_SNR_" + str(snr_db) + ".pth")


def PSNR_test(model_name, testloader, SNRdB, Rayleigh) :
    """Load a saved Autoencoder and report its mean reconstruction PSNR.

    Args:
        model_name: path of a state_dict saved by the training loop.
        testloader: iterable of batches whose first element is the image
            tensor; images are flattened to (-1, args['input_dim']) for the
            model and reshaped back to (-1, 1, 32, 32).
        SNRdB: channel SNR in dB used for the forward pass.
        Rayleigh: channel selector forwarded to Autoencoder.forward (0 = AWGN).

    Returns:
        Mean per-image PSNR in dB (peak value 1), rounded to 2 decimals.
        The value is also printed.

    Raises:
        ValueError: if `testloader` yields no images.
    """
    model_1 = Autoencoder()
    # torch.load unpickles the file; only use it on checkpoints you trust
    # (here, the ones written by the training loop).
    model_1.load_state_dict(torch.load(model_name))
    model_1.eval()

    psnr_sum = 0.0
    num_images = 0

    # Inference only: skip building an autograd graph for every test batch.
    with torch.no_grad():
        for data in testloader:
            inputs = data[0]
            outputs = model_1(inputs.view(-1, args['input_dim']), SNRdB = SNRdB, Rayleigh = Rayleigh)
            outputs = outputs.view(-1, 1, 32, 32)

            # Per-image MSE over all pixels. The peak value is 1 because the
            # decoder ends in a Sigmoid. Double precision mirrors the former
            # math.log10.
            mse = torch.pow(inputs - outputs, 2).flatten(start_dim=1).mean(dim=1).double()
            psnr_sum += (10 * torch.log10(1 / mse)).sum().item()
            num_images += len(inputs)

    if num_images == 0:
        raise ValueError("testloader yielded no images")

    # Average over images, not over batches, so a short final batch is not
    # over-weighted.
    mean_psnr = round(psnr_sum / num_images, 2)
    print("PSNR : {}".format(mean_psnr))

    return mean_psnr

# Evaluate each per-SNR checkpoint at the SNR it was trained for (AWGN channel)
# and plot PSNR against channel SNR.
PSNR_list = []

for SNRdB in args['SNR_dB']:
    model_name = "model_AWGN" + "_SNR_" + str(SNRdB) + ".pth"
    PSNR = PSNR_test(model_name, testloader, SNRdB, 0)
    PSNR_list.append(PSNR)

plt.plot(args['SNR_dB'], PSNR_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.xlabel("SNR (dB)")
plt.ylabel("PSNR (dB)")
plt.grid(True)
plt.legend()
plt.show()

'Wireless Comm. > Python' 카테고리의 다른 글

cifar10 HPF filter size에 따른 결과 확인  (0) 2023.06.08
***Cifar10_AE(color)_20230608  (0) 2023.06.08
send img with contour : 20230603  (0) 2023.06.03
send img with contour  (0) 2023.06.02
trainloader image 확인  (0) 2023.06.01
Comments