Traditional AutoEncoder Cifar10(gray) 2023 06 07
Happy PinGu 2023. 6. 7. 19:05

import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torch.utils.data import DataLoader
args = {
'BATCH_SIZE' : 50,
'LEARNING_RATE' : 0.001,
'NUM_EPOCH' : 80,
'SNR_dB' : [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
'latent_dim' : 100,
'input_dim' : 32 * 32
}
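
As a side note (not part of the original post), these settings imply a latent code of 100 symbols for every 32 x 32 = 1024 input pixels, i.e. a compression ratio of roughly 1/10:

# Hypothetical sanity check, not in the original post:
# latent symbols per source pixel implied by the settings above.
print(args['latent_dim'] / args['input_dim'])   # 100 / 1024 ≈ 0.0977
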
transf = tr.Compose([tr.ToTensor(), tr.Grayscale(num_output_channels = 1)])
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)
trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = False)
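
A quick, optional sanity check (not in the original post): after tr.Grayscale(num_output_channels = 1), each batch from trainloader should have shape (BATCH_SIZE, 1, 32, 32), which is what the .view(-1, 1, 32, 32) calls below assume.

# Illustrative shape check only; the expected size follows the settings above.
sample_images, sample_labels = next(iter(trainloader))
print(sample_images.shape)   # expected: torch.Size([50, 1, 32, 32])
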
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(args['input_dim'], 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, args['latent_dim']),
            nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(args['latent_dim'], 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, args['input_dim']),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(x)
class Autoencoder(nn.Module):
    def __init__(
        self,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()
    def add_AWGN2tensor(self, input, SNRdB):
        # Power-normalize the latent vector, then add Gaussian noise whose
        # standard deviation is derived from the target SNR (in dB).
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        std = 1 / math.sqrt(SNR)
        noise = torch.normal(0, std, size=normalized_tensor.size())
        return normalized_tensor + noise
    def forward(self, x, SNRdB, Rayleigh):
        encoded = self.encoder(x)
        if Rayleigh == 1:
            # Rayleigh fading is not implemented in this post; only AWGN is used.
            raise NotImplementedError("Rayleigh channel not implemented")
        elif Rayleigh == 0:
            encoded_after_channel = self.add_AWGN2tensor(encoded, SNRdB)
        decoded = self.decoder(encoded_after_channel)
        return decoded
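
For reference, add_AWGN2tensor maps the SNR in dB to a linear SNR and uses std = 1 / sqrt(10^(SNRdB / 10)) as the noise standard deviation, assuming unit signal power after normalization. A minimal sketch of that mapping (illustrative only, not part of the training code):

# Illustrative only: noise standard deviation implied by a few SNR settings.
for snr_db in [1, 10, 17]:
    snr_linear = 10.0 ** (snr_db / 10.0)
    print(snr_db, "dB -> std", round(1 / math.sqrt(snr_linear), 3))
# roughly: 1 dB -> 0.891, 10 dB -> 0.316, 17 dB -> 0.141
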
# Train one autoencoder per SNR value and save each checkpoint.
for i in range(len(args['SNR_dB'])):
    model = Autoencoder()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])

    print("+++++ SNR = {} Training Start! +++++".format(args['SNR_dB'][i]))

    for epoch in range(args['NUM_EPOCH']):
        running_loss = 0.0
        for data in trainloader:
            inputs = data[0]
            optimizer.zero_grad()
            outputs = model(inputs.view(-1, args['input_dim']), SNRdB = args['SNR_dB'][i], Rayleigh = 0)
            outputs = outputs.view(-1, 1, 32, 32)
            loss = criterion(inputs, outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        cost = running_loss / len(trainloader)
        print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))

    print()
    PATH = "./"
    torch.save(model.state_dict(), PATH + "model_AWGN" + "_SNR_" + str(args['SNR_dB'][i]) + ".pth")
def PSNR_test(model_name, testloader, SNRdB, Rayleigh):
    model_1 = Autoencoder()
    model_1.load_state_dict(torch.load(model_name))
    model_1.eval()

    PSNR_list = []
    with torch.no_grad():
        for data in testloader:
            inputs = data[0]
            outputs = model_1(inputs.view(-1, args['input_dim']), SNRdB = SNRdB, Rayleigh = Rayleigh)
            outputs = outputs.view(-1, 1, 32, 32)

            PSNR = 0
            for j in range(len(inputs)):
                # Per-image MSE, then PSNR assuming pixel values in [0, 1] (peak = 1).
                MSE = torch.sum(torch.pow(inputs[j][0] - outputs[j][0], 2)) / torch.numel(inputs[j][0])
                PSNR += 10 * math.log10(1 / MSE)
            PSNR_list.append(PSNR / len(inputs))

    print("PSNR : {}".format(round(sum(PSNR_list) / len(PSNR_list), 2)))
    return round(sum(PSNR_list) / len(PSNR_list), 2)
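
The metric computed inside PSNR_test is the standard PSNR for images scaled to [0, 1]: PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1, so the peak term drops out. A small standalone sketch of the same formula on hypothetical tensors (not from the post):

# Hypothetical example of the PSNR formula used above (peak value = 1).
x = torch.rand(1, 32, 32)                                 # "original" image
y = torch.clamp(x + 0.05 * torch.randn_like(x), 0, 1)     # noisy "reconstruction"
mse = torch.mean((x - y) ** 2)
print(round(10 * math.log10(1.0 / mse.item()), 2))        # roughly 26 dB for noise std 0.05
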
PSNR_list = []
for i in range(len(args['SNR_dB'])):
    SNRdB = args['SNR_dB'][i]
    model_name = "model_AWGN" + "_SNR_" + str(SNRdB) + ".pth"
    PSNR = PSNR_test(model_name, testloader, SNRdB, 0)
    PSNR_list.append(PSNR)
plt.plot(args['SNR_dB'], PSNR_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.grid(True)
plt.legend()
plt.show()