UOMOP
Random Selection 본문
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
from torch.utils.data import DataLoader, Dataset
import time
from params import *
import os
from tqdm import tqdm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def random_patch_masking(images, patch_size, mask_ratio):
    """Zero out a random subset of non-overlapping square patches in each image.

    Args:
        images: 4-D tensor laid out as (batch, channels, height, width).
        patch_size: Side length, in pixels, of each square patch.
        mask_ratio: Fraction of the patches of each image to zero out. The
            number of masked patches is ``int(n_patches * mask_ratio)``.

    Returns:
        A new tensor of the same shape as ``images`` with the chosen patches
        set to 0 across all channels. ``images`` itself is not modified.

    Notes:
        Patches are selected independently per image with one
        ``torch.randperm`` call per image, in batch order. Rows/columns that
        do not fit into a whole patch (when the height or width is not a
        multiple of ``patch_size``) are never masked.
    """
    B, C, H, W = images.shape
    n_patch_rows = H // patch_size   # patches stacked along the height axis
    n_patch_cols = W // patch_size   # patches laid out along the width axis
    n_patches_per_image = n_patch_rows * n_patch_cols
    n_patches_to_mask = int(n_patches_per_image * mask_ratio)

    masked_images = images.clone()
    if n_patches_per_image == 0:
        # patch_size exceeds the image: there is no whole patch to mask.
        return masked_images

    # One boolean flag per patch (row-major patch index), filled per image.
    patch_mask = torch.zeros(B, n_patches_per_image, dtype=torch.bool)
    for i in range(B):
        mask_indices = torch.randperm(n_patches_per_image)[:n_patches_to_mask]
        patch_mask[i, mask_indices] = True

    # Blow the patch grid up to pixel resolution so the whole batch can be
    # masked with a single in-place fill instead of one slice write per patch.
    pixel_mask = (
        patch_mask.view(B, 1, n_patch_rows, n_patch_cols)
        .repeat_interleave(patch_size, dim=2)
        .repeat_interleave(patch_size, dim=3)
        .to(masked_images.device)
    )
    covered_h = n_patch_rows * patch_size
    covered_w = n_patch_cols * patch_size
    # The slice is a view, so the in-place fill writes into masked_images;
    # the single-channel mask broadcasts over the channel axis.
    masked_images[:, :, :covered_h, :covered_w].masked_fill_(pixel_mask, 0)
    return masked_images
class Encoder(nn.Module):
    """Convolutional encoder mapping a 3-channel image to a latent vector.

    Three stride-2 3x3 convolutions (each followed by ReLU) are followed by a
    flatten and a linear projection to ``latent_dim``. The linear layer expects
    a flattened 128 x 4 x 4 feature map, which is what a 32 x 32 input yields
    after the three stride-2 stages.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # (in_channels, out_channels) of each downsampling stage; for a
        # 32x32 input the spatial size goes 32 -> 16 -> 8 -> 4.
        channel_plan = [(3, 32), (32, 64), (64, 128)]

        layers = []
        for in_ch, out_ch in channel_plan:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())
        layers.append(nn.Linear(4 * 4 * 128, self.latent_dim))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of images into latent vectors of size ``latent_dim``."""
        latent = self.encoder(x)
        return latent
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 3-channel image.

    A linear layer expands the latent vector to a 128 x 4 x 4 feature map,
    which is then upsampled by three stride-2 transposed convolutions
    (4 -> 8 -> 16 -> 32), with a 3x3 refinement convolution after each of the
    first two. A final sigmoid bounds the output to (0, 1).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        def upsample(in_ch, out_ch):
            # Doubles the spatial resolution.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

        def refine(channels):
            # Keeps both resolution and channel count.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        stages = [
            nn.Linear(self.latent_dim, 4 * 4 * 128),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            upsample(128, 64),   # -> [batch, 64, 8, 8]
            nn.ReLU(),
            refine(64),
            nn.ReLU(),
            upsample(64, 32),    # -> [batch, 32, 16, 16]
            nn.ReLU(),
            refine(32),
            nn.ReLU(),
            upsample(32, 3),     # -> [batch, 3, 32, 32]
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a batch of latent vectors into images."""
        reconstruction = self.decoder(x)
        return reconstruction
class Autoencoder(nn.Module):
    """Encoder -> noisy channel -> decoder autoencoder.

    The latent vector produced by the encoder is L2-normalised, corrupted with
    additive white Gaussian noise at a requested SNR, and then decoded.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def AWGN(self, input, SNRdB):
        """Normalise ``input`` along dim 1 and add Gaussian noise at ``SNRdB``.

        Args:
            input: Tensor whose dim 1 holds the latent vector of each sample.
            SNRdB: Signal-to-noise ratio in decibels.

        Returns:
            The unit-norm latent plus zero-mean Gaussian noise, on the same
            device and with the same dtype as the normalised input.
        """
        normalized_tensor = f.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        # Each sample has unit total power spread over latent_dim entries, so
        # the per-entry noise variance for the requested SNR is
        # 1 / (latent_dim * SNR).
        std = 1 / math.sqrt(self.latent_dim * SNR)
        # Sample directly on the tensor's own device/dtype rather than on the
        # CPU followed by a copy to a module-level device, which breaks when
        # the model lives on a different device than that global.
        n = torch.randn_like(normalized_tensor) * std
        return normalized_tensor + n

    def forward(self, x, SNRdB):
        """Encode ``x``, pass the latent through the AWGN channel, and decode."""
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)
        return decoded
def _psnr_from_mse(mse):
    """Convert a mean-squared error (peak signal value 1.0) to PSNR in dB.

    The result is rounded to 3 decimals. A non-positive ``mse`` yields
    ``inf`` instead of raising ``ZeroDivisionError``.
    """
    if mse <= 0:
        return float('inf')
    return round(10 * math.log10(1.0 / mse), 3)


def train(latent_dim, patch_size, mask_ratio, trainloader, testloader):
    """Train one autoencoder per SNR listed in ``params['SNR']``.

    For every SNR a fresh ``Autoencoder`` is created and optimised with Adam
    on the MSE between the clean images and the reconstruction of their
    randomly patch-masked versions. After each epoch the model is evaluated on
    ``testloader``; training for that SNR stops early once the evaluation loss
    has not improved for 20 consecutive epochs. Whenever the evaluation PSNR
    reaches a new maximum, the whole model is saved under ``trained_model/``.

    Args:
        latent_dim: Size of the latent vector of the autoencoder.
        patch_size: Side length of the square patches that get masked.
        mask_ratio: Fraction of patches masked in every image.
        trainloader: Iterable of batches whose first element is the image
            tensor; used for optimisation.
        testloader: Same batch format as ``trainloader``; used for early
            stopping and checkpoint selection.

    Returns:
        None. Progress is printed and checkpoints are written to disk.
    """
    save_folder = 'trained_model'
    n_epochs_stop = 20  # patience: epochs without improvement before stopping

    for snr_db in params['SNR']:
        model = Autoencoder(latent_dim=latent_dim).to(device)
        print("Model size : {}".format(count_parameters(model)))
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=params['LR'])

        min_test_cost = float('inf')
        epochs_no_improve = 0  # consecutive epochs without a lower eval loss
        print("+++++ SNR = {} Training Start! +++++\t".format(snr_db))
        max_psnr = 0

        for epoch in range(params['EP']):
            # ================================ Train ================================
            train_loss = 0.0
            model.train()
            timetemp = time.time()
            for data in trainloader:
                inputs = data[0].to(device)
                masked_inputs = random_patch_masking(inputs, patch_size, mask_ratio)
                optimizer.zero_grad()
                outputs = model(masked_inputs, SNRdB=snr_db)
                # (prediction, target) order; the target is the unmasked image.
                loss = criterion(outputs, inputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_cost = train_loss / len(trainloader)
            tr_psnr = _psnr_from_mse(train_cost)

            # ============================== Evaluate ===============================
            test_loss = 0.0
            model.eval()
            with torch.no_grad():
                for data in testloader:
                    inputs = data[0].to(device)
                    masked_inputs = random_patch_masking(inputs, patch_size, mask_ratio)
                    outputs = model(masked_inputs, SNRdB=snr_db)
                    loss = criterion(outputs, inputs)
                    test_loss += loss.item()
            test_cost = test_loss / len(testloader)
            test_psnr = _psnr_from_mse(test_cost)

            # Early-stopping bookkeeping.
            if test_cost < min_test_cost:
                min_test_cost = test_cost
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
            if epochs_no_improve >= n_epochs_stop:
                print("Early stopping!")
                break  # stop training for this SNR

            # Wall-clock time of the full epoch (training plus evaluation).
            training_time = time.time() - timetemp
            print(
                "[{:>3}-Epoch({:>5}sec.)] PSNR(Train / Val) : {:>6.4f} / {:>6.4f} Loss(Train / Val) : {:>5.5f} / {:>5.5f}".format(
                    epoch + 1, round(training_time, 2), tr_psnr, test_psnr, train_cost, test_cost))

            if test_psnr > max_psnr:
                os.makedirs(save_folder, exist_ok=True)
                max_psnr = test_psnr
                save_path = os.path.join(
                    save_folder,
                    "RS(PS=" + str(patch_size) + "_DIM=" + str(latent_dim)
                    + "_MR=" + str(mask_ratio) + "_SNR=" + str(snr_db)
                    + "_PSNR=" + str(max_psnr) + ").pt")
                # NOTE: this pickles the whole module (not just a state_dict),
                # and the PSNR in the file name means every improvement
                # produces a new file rather than overwriting the last one.
                torch.save(model, save_path)
'''
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(images[0].permute(1, 2, 0)) # Unnormalize
plt.title('Original Image')
plt.subplot(1, 2, 2)
plt.imshow(masked_images[0].permute(1, 2, 0)) # Unnormalize
plt.title('Masked Image')
plt.show()
'''
if __name__ == '__main__':
    # CIFAR-10 images converted to tensors; no further augmentation.
    transform = transforms.Compose([transforms.ToTensor()])

    train_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    trainloader = DataLoader(train_cifar10, batch_size=params['BS'], shuffle=True)
    # The evaluation split drives early stopping and checkpoint selection, so
    # it is iterated in a fixed order to keep batch composition stable.
    testloader = DataLoader(test_cifar10, batch_size=params['BS'], shuffle=False)

    # Sweep every (patch size, latent dim, mask ratio) combination; train()
    # itself loops over the SNR values.
    for patch_size in params['PS']:
        for latent_dim in params['DIM']:
            for mask_ratio in params['MR']:
                train(latent_dim, patch_size, mask_ratio, trainloader, testloader)
# Hyper-parameters and sweep grids read by the training script.
params = dict(
    BS=64,                        # batch size passed to the DataLoaders
    LR=0.0005,                    # learning rate for the Adam optimizer
    EP=500,                       # maximum number of epochs per training run
    SNR=[0, 15, 30],              # channel SNR values (dB), one model trained per value
    DIM=[32, 128, 512],           # latent dimensions to sweep
    MR=[0, 0.25, 0.5, 0.75, 1],   # patch mask ratios to sweep
    PS=[2, 4, 8, 16],             # patch sizes (pixels) to sweep
)
'Main' 카테고리의 다른 글
High Attention Selection Performance (CR : 1/6, 1/24, 1/96) (PS : 2) (0) | 2024.03.18 |
---|---|
Low Attention Selection Performance (CR : 1/6, 1/24, 1/96) (PS : 2) (1) | 2024.03.15 |
High Attention Selection (0) | 2024.03.15 |
Low Attention Selection (0) | 2024.03.15 |
Patch selection with zero padding (0) | 2024.03.13 |
Comments