# UOMOP
#
# AutoEncoder for Fashion MNIST — 본문 (post body)
#
# Wireless Comm./Python
#
# AutoEncoder for Fashion MNIST
#
# Happy PinGu 2023. 5. 11. 23:50
import cv2
import math
import time
import random

import numpy as np
from PIL import Image
from numpy import sqrt
from tqdm import trange
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision import transforms, datasets

# Fashion-MNIST training split; download=True lets torchvision fetch the data into root.
train_dataset = datasets.FashionMNIST(
    root="../data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Fashion-MNIST test split, read from the same root directory. No download is
# requested here: per the original author's note, the files are already present
# after the training-split call above.
test_dataset = datasets.FashionMNIST(
    root="../data/FashionMNIST",
    train=False,
    transform=transforms.ToTensor(),
)
                                      
# Hyperparameters read by the data loaders, the model's forward pass and the
# training loop below.
args = dict(
    BATCH_SIZE=64,          # samples per mini-batch (both loaders)
    LEARNING_RATE=0.001,    # Adam learning rate
    NUM_EPOCH=80,           # number of passes over the training set
    SNR_dB=20,              # channel SNR in dB used by Autoencoder.forward
    # NOTE(review): not read anywhere in the visible code. Presumably the number of
    # complex channel symbols (the encoder's 100 real outputs are paired into 50).
    latent_dim=50,
    input_dim=28 * 28,      # flattened image length fed to the encoder
)

# Mini-batch loaders. These define how the data are served during loading:
# the training set is reshuffled every epoch, the test set keeps its order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args['BATCH_SIZE'],
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=args['BATCH_SIZE'],
    shuffle=False,
)
                                           
class Autoencoder(nn.Module):
    """Fully connected autoencoder with an AWGN channel between encoder and decoder.

    The encoder maps a flattened input of ``input_dim`` values to ``latent_dim``
    real values. Those values are paired into complex symbols and corrupted by
    additive white Gaussian noise at a given SNR. The decoder then maps them back
    to ``input_dim`` values squashed by a sigmoid.
    """

    def __init__(self, input_dim=28 * 28, latent_dim=100):
        """Build the encoder/decoder MLPs.

        Args:
            input_dim: length of the flattened input (default 784 = 28 * 28).
            latent_dim: number of real latent values produced by the encoder
                (default 100, i.e. 50 complex channel symbols).
        """
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        )

    def addAWGN_latent(self, encoded, SNR_dB):
        """Pass a batch of latent values through a complex AWGN channel.

        Consecutive pairs of the row-major flattened batch are treated as the real
        and imaginary parts of one complex symbol s. Complex Gaussian noise n is
        added, with E|n|^2 = N0 = P / SNR. P is the mean of |s|^2 over all symbols
        in the batch, so the realised SNR matches ``SNR_dB``.

        Runs entirely in torch, so gradients flow back to the encoder and the
        result stays on the input's device. Noise is drawn from torch's global
        RNG; seed it with ``torch.manual_seed``.

        Args:
            encoded: latent tensor, normally (batch_size, latent_dim). Its total
                element count must be even.
            SNR_dB: target signal-to-noise ratio in dB.

        Returns:
            float32 tensor with the same shape as ``encoded``.

        Raises:
            ValueError: if ``encoded`` has an odd number of elements.
        """
        if encoded.numel() % 2 != 0:
            raise ValueError(
                "addAWGN_latent needs an even number of latent values to form "
                f"complex symbols, got {encoded.numel()}"
            )

        # Each row is one complex symbol: column 0 = real part, column 1 = imaginary.
        symbols = encoded.reshape(-1, 2)

        # Average symbol power E|s|^2 (the original used E|s|, which skews the SNR).
        signal_power = symbols.pow(2).sum(dim=1).mean()

        snr_linear = 10 ** (SNR_dB / 10)
        # Treat the noise level as a channel property, not a path for gradients.
        n0 = signal_power.detach() / snr_linear

        # Split N0 evenly between the real and imaginary components.
        noise = torch.randn_like(symbols) * torch.sqrt(n0 / 2)

        return (symbols + noise).reshape(encoded.shape).to(torch.float32)

    def forward(self, x, SNR_dB=None):
        """Encode ``x``, pass the latent through the AWGN channel, and decode.

        Args:
            x: flattened input batch, (batch_size, input_dim).
            SNR_dB: channel SNR in dB. Defaults to the module-level
                ``args['SNR_dB']`` when omitted.
        """
        if SNR_dB is None:
            SNR_dB = args['SNR_dB']
        encoded = self.encoder(x)
        encoded_AWGN = self.addAWGN_latent(encoded, SNR_dB)
        return self.decoder(encoded_AWGN)
        
# Model, reconstruction loss and optimizer.
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args['LEARNING_RATE'])

for epoch in range(args['NUM_EPOCH']):
    model.train()
    running_loss = 0.0

    # Labels are irrelevant for reconstruction; only the images are used.
    for inputs, _ in train_loader:
        optimizer.zero_grad()
        # Flatten the images for the fully connected encoder, then restore the
        # image shape (presumably (B, 1, 28, 28)) so the output lines up with
        # the input.
        outputs = model(inputs.view(-1, args['input_dim']))
        outputs = outputs.view(-1, 1, 28, 28)
        # nn.MSELoss is called as criterion(prediction, target).
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean per-batch training loss for this epoch.
    cost = running_loss / len(train_loader)
    print("[{}] Loss : {}".format(epoch + 1, round(cost, 6)))

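# 'Wireless Comm. > Python' 카테고리의 다른 글 (Other posts in the 'Wireless Comm. > Python' category)
#
# Add AWGN to image  (0) 2023.05.24
# plt.imshow 파란색으로 되는 현상 (Why plt.imshow renders images in blue)  (0) 2023.05.24
# Image PSNR Check_20230508  (0) 2023.05.08
# Conventional_20230504  (0) 2023.05.04
# save2  (0) 2023.05.03
# Comments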