UOMOP
Good
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from params import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#print(device)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
if isinstance(image, torch.Tensor):
image = image.numpy()
H, W = image.shape[-2:]
extended_patch_size = patch_size + 2 * how_many
value_map = np.zeros((H // patch_size, W // patch_size))
for i in range(0, H, patch_size):
for j in range(0, W, patch_size):
start_i = max(i - how_many, 0)
end_i = min(i + patch_size + how_many, H)
start_j = max(j - how_many, 0)
end_j = min(j + patch_size + how_many, W)
extended_patch = image[start_i:end_i, start_j:end_j]
if type == 'variance':
value = np.std(extended_patch)
elif type == 'mean_brightness':
value = np.mean(extended_patch)
elif type == 'contrast':
value = extended_patch.max() - extended_patch.min()
elif type == 'edge_density':
dy, dx = np.gradient(extended_patch)
value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
elif type == 'color_diversity':
value = np.std(extended_patch)
noise = np.random.randn() * noise_scale
value_map[i // patch_size, j // patch_size] = value + noise
return value_map
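def _demo_patch_importance():
    # Hypothetical sanity check (not part of the original script, never called by it):
    # on a 32x32 single-channel image with patch_size=2, patch_importance returns one
    # score per 2x2 patch, i.e. a 16x16 value map.
    dummy = np.random.rand(32, 32).astype(np.float32)
    value_map = patch_importance(dummy, patch_size=2, type='variance', how_many=1, noise_scale=0)
    print(value_map.shape)  # (16, 16)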
def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1, noise_scale=0):
B, C, H, W = images.shape
masked_images = images.clone()
unmasked_counts = []
unmasked_patches = []
patch_index = []
target_unmasked_ratio = 1 - mask_ratio
num_patches = (H // patch_size) * (W // patch_size)
target_unmasked_patches = int(num_patches * target_unmasked_ratio)
for b in range(B):
patch_importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)
mask = np.zeros((H // patch_size, W // patch_size), dtype=bool)
for i in range(H // patch_size):
for j in range(W // patch_size):
if (i + j) % 2 == 0:
mask[i, j] = True
unmasked_count = np.sum(~mask)
if mask_ratio < 0.5:
masked_indices = np.argwhere(mask)
importances = patch_importance_map[mask]
sorted_indices = masked_indices[np.argsort(importances)[::-1]]
for idx in sorted_indices:
if unmasked_count >= target_unmasked_patches:
break
mask[tuple(idx)] = False
unmasked_count += 1
elif mask_ratio > 0.5:
unmasked_indices = np.argwhere(~mask)
importances = patch_importance_map[~mask]
sorted_indices = unmasked_indices[np.argsort(importances)]
for idx in sorted_indices:
if unmasked_count <= target_unmasked_patches:
break
mask[tuple(idx)] = True
unmasked_count -= 1
patches = []
for i in range(H // patch_size):
for j in range(W // patch_size):
if mask[i, j]:
masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0
else:
patch = images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size]
patches.append(patch)
patch_index.append((H // patch_size)*i + j)
unmasked_patches.append(torch.cat(patches, dim=-1))
unmasked_counts.append(unmasked_count)
unmasked_patches_image = torch.cat(unmasked_patches, dim=-1)
#print(unmasked_patches_image.shape)
# mask_ratio = 0.33984 keeps exactly int(256 * 0.66016) = 169 (= 13 * 13) of the 256 patches
# of a 32x32 image, so the [3, 2, 338] strip of kept 2x2 patches can be repacked into a
# square [3, 26, 26] tensor (26 = 13 * patch_size); 'reshaped' is only defined for this ratio.
if mask_ratio == 0.33984:
split_len = 26
split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
reshaped = torch.cat(split_tensor, dim=1)
return masked_images, reshaped, torch.tensor(patch_index)
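def _demo_chessboard_mask():
    # Hypothetical shape walk-through (not part of the original script, never called by it).
    # chessboard_mask is used with a single image (B=1) in preprocess_and_save_dataset.
    # With 32x32 CIFAR-10 images and patch_size=2 there are 256 patches; for
    # mask_ratio=0.33984 exactly 169 patches survive and are repacked as described above.
    img = torch.rand(1, 3, 32, 32)
    masked, packed, idx = chessboard_mask(img, patch_size=2, mask_ratio=0.33984)
    print(masked.shape)  # torch.Size([1, 3, 32, 32]) - masked patches zeroed out
    print(packed.shape)  # torch.Size([3, 26, 26])    - 169 kept 2x2 patches, repacked
    print(idx.shape)     # torch.Size([169])          - linear indices of the kept patches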
class Encoder1D(nn.Module):
def __init__(self, latent_dim):
super(Encoder1D, self).__init__()
self.latent_dim = latent_dim
self.encoder = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
nn.PReLU(),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
nn.PReLU(),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
nn.PReLU(),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
nn.PReLU(),
nn.Flatten(),  # -> [B, 128 * 2 * 2] = [B, 512] for the 3x26x26 packed input
nn.Linear(512, self.latent_dim),
)
def forward(self, x):
encoded = self.encoder(x)
return encoded
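def _demo_encoder_shapes():
    # Hypothetical shape check (not part of the original script, never called by it):
    # with the 3x26x26 packed input produced above, the four stride-2 convolutions
    # reduce the spatial size 26 -> 13 -> 7 -> 4 -> 2, so Flatten yields
    # 128 * 2 * 2 = 512 features, matching the in_features of the final Linear layer.
    encoder = Encoder1D(latent_dim=512)
    z = encoder(torch.rand(4, 3, 26, 26))
    print(z.shape)  # torch.Size([4, 512])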
class Decoder1D(nn.Module):
def __init__(self, latent_dim):
super(Decoder1D, self).__init__()
self.latent_dim = latent_dim
self.decoder = nn.Sequential(
nn.Linear(self.latent_dim, 128 * 2*2),
nn.PReLU(),
nn.Unflatten(1, (128, 2, 2)),
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 128, 2, 2] -> [B, 64, 4, 4]
nn.PReLU(),
nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=0),  # [B, 64, 4, 4] -> [B, 64, 7, 7]
nn.PReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=0),  # [B, 64, 7, 7] -> [B, 32, 13, 13]
nn.PReLU(),
nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1),  # [B, 32, 13, 13] -> [B, 3, 26, 26]
nn.Sigmoid()
)
def reconstruct_masked_image(self, unmasked_patches, patch_indices, image_shape, patch_size):
B, C, H, W = image_shape
reconstructed_images = torch.zeros((B, C, H, W)).to(unmasked_patches.device)
for b in range(B):
patches = []
for i in range(len(patch_indices[b])):
patches.append(unmasked_patches[b, :, patch_size * i: patch_size * (i + 1)])
for idx, linear_idx in enumerate(patch_indices[b]):
i = linear_idx // (W // patch_size)
j = linear_idx % (W // patch_size)
reconstructed_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = patches[idx].unsqueeze(2)
return reconstructed_images
def forward(self, x, patch_indices, image_shape, patch_size):
decoded = self.decoder(x)
#masked_image_recon = self.reconstruct_masked_image(decoded, patch_indices, image_shape, patch_size)
return decoded
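def _demo_decoder_shapes():
    # Hypothetical shape check (not part of the original script, never called by it):
    # the transposed convolutions invert the encoder, 2 -> 4 -> 7 -> 13 -> 26, so the
    # decoder maps a latent vector back to the packed 3x26x26 patch image. The
    # patch_indices, image_shape and patch_size arguments are unused here because
    # reconstruct_masked_image is commented out in forward.
    decoder = Decoder1D(latent_dim=512)
    out = decoder(torch.randn(4, 512), patch_indices=None, image_shape=None, patch_size=None)
    print(out.shape)  # torch.Size([4, 3, 26, 26])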
class Autoencoder1D(nn.Module):
def __init__(self, latent_dim, patch_size):
super(Autoencoder1D, self).__init__()
self.latent_dim = latent_dim
self.patch_size = patch_size
self.encoder = Encoder1D(latent_dim)
self.decoder = Decoder1D(latent_dim)
def Power_norm(self, z, P=1 / np.sqrt(2)):
batch_size, z_dim = z.shape
z_power = torch.sqrt(torch.sum(z ** 2, 1))
z_M = z_power.repeat(z_dim, 1)
return np.sqrt(P * z_dim) * z / z_M.t()
def Power_norm_complex(self, z, P=1 / np.sqrt(2)):
batch_size, z_dim = z.shape
z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2])
z_com_conj = torch.complex(z[:, 0:z_dim:2], -z[:, 1:z_dim:2])
z_power = torch.sum(z_com * z_com_conj, 1).real
z_M = z_power.repeat(z_dim // 2, 1)
z_nlz = np.sqrt(P * z_dim) * z_com / torch.sqrt(z_M.t())
z_out = torch.zeros(batch_size, z_dim).to(device)
z_out[:, 0:z_dim:2] = z_nlz.real
z_out[:, 1:z_dim:2] = z_nlz.imag
return z_out
def AWGN_channel(self, x, snr, P=1):
batch_size, length = x.shape
gamma = 10 ** (snr / 10.0)
noise = np.sqrt(P / gamma) * torch.randn(batch_size, length).to(device)
y = x + noise
return y
def Fading_channel(self, x, snr, P=1):
gamma = 10 ** (snr / 10.0)
[batch_size, feature_length] = x.shape
K = feature_length // 2
h_I = torch.randn(batch_size, K).to(device)
h_R = torch.randn(batch_size, K).to(device)
h_com = torch.complex(h_I, h_R)
x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2])
y_com = h_com * x_com
n_I = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
n_R = np.sqrt(P / gamma) * torch.randn(batch_size, K).to(device)
noise = torch.complex(n_I, n_R)
y_add = y_com + noise
y = y_add / h_com
y_out = torch.zeros(batch_size, feature_length).to(device)
y_out[:, 0:feature_length:2] = y.real
y_out[:, 1:feature_length:2] = y.imag
return y_out
def forward(self, x, SNRdB, channel, patch_index, image_shape, patch_size):
encoded = self.encoder(x)
if channel == 'AWGN':
normalized_x = self.Power_norm(encoded)
channel_output = self.AWGN_channel(normalized_x, SNRdB)
elif channel == 'Rayleigh':
normalized_complex_x = self.Power_norm_complex(encoded)
channel_output = self.Fading_channel(normalized_complex_x, SNRdB)
decoded = self.decoder(channel_output, patch_index, image_shape, patch_size)
return decoded
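def _demo_channel_pass():
    # Hypothetical end-to-end check (not part of the original script, never called by it):
    # a batch of packed 3x26x26 inputs is encoded, power-normalized, passed through the
    # Rayleigh fading channel at a given SNR, and decoded back to 3x26x26.
    model = Autoencoder1D(latent_dim=512, patch_size=2).to(device)
    x = torch.rand(4, 3, 26, 26).to(device)
    y = model(x, SNRdB=10, channel='Rayleigh', patch_index=None, image_shape=x.shape, patch_size=2)
    print(y.shape)  # torch.Size([4, 3, 26, 26])
    # Power_norm scales each latent vector so its average power per dimension is P = 1/sqrt(2):
    z = model.Power_norm(torch.randn(4, 512).to(device))
    print((z ** 2).sum(dim=1) / 512)  # ~0.7071 for every row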
def preprocess_and_save_dataset(dataset, root_dir, patch_size, mask_ratio, importance_type, how_many, noise_scale):
os.makedirs(root_dir, exist_ok=True)
for i, (images, _) in tqdm(enumerate(dataset), total=len(dataset)):
masked_images, unmasked_patches_image, patch_index = chessboard_mask(images.unsqueeze(0), patch_size, mask_ratio, importance_type, how_many, noise_scale)
torch.save({
'original_images': images,
'masked_images' : masked_images.squeeze(0),
'unmasked_patches': unmasked_patches_image,
'patch_index' : patch_index
}, os.path.join(root_dir, f'data_{i}.pt'))
class PreprocessedCIFAR10Dataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.file_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.pt')]
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
file_path = self.file_paths[idx]
data = torch.load(file_path, weights_only=False)
original_images = data['original_images']
masked_images = data['masked_images']
unmasked_patches = data['unmasked_patches']
patch_index = data['patch_index']
if self.transform:
unmasked_patches = self.transform(unmasked_patches)
masked_images = self.transform(masked_images)
original_images = self.transform(original_images)
patch_index = self.transform(patch_index)
return original_images, masked_images, unmasked_patches, patch_index
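def _demo_preprocessed_dataset():
    # Hypothetical usage sketch (not part of the original script, never called by it).
    # Assumes the preprocessing step in __main__ has already written this folder for
    # PS=2, MR=0.33984, IT='variance', HM=1.
    ds = PreprocessedCIFAR10Dataset(root_dir='ProcessedTrain(PS=2_MR=0.33984_IT=variance_HM=1)')
    original, masked, patches, idx = ds[0]
    print(original.shape, masked.shape, patches.shape, idx.shape)
    # torch.Size([3, 32, 32]) torch.Size([3, 32, 32]) torch.Size([3, 26, 26]) torch.Size([169])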
def train(latent_dim, patch_size, mask_ratio, trainloader, testloader, ES, IT, HM, NS):
for snr_i in range(len(params['SNR'])) :
model = Autoencoder1D(latent_dim=latent_dim, patch_size=patch_size).to(device)
print("Model size : {}".format(count_parameters(model)))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=params['LR'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=8, factor=0.5, verbose=True)
min_test_cost = float('inf')
epochs_no_improve = 0
n_epochs_stop = ES
print("+++++ SNR = {} Training Start! +++++\t".format(params['SNR'][snr_i]))
max_psnr = 0
previous_best_model_path = None
for epoch in range(params['EP']):
# ========================================== Train ==========================================
train_loss = 0.0
model.train()
timetemp = time.time()
for original_images, masked_images, unmasked_patches, patch_index in trainloader:
unmasked_patches = unmasked_patches.squeeze(2).to(device)
original_images = original_images.to(device)
image_shape = original_images.shape
optimizer.zero_grad()
outputs = model(unmasked_patches, SNRdB = params['SNR'][snr_i], channel = params['channel'], patch_index = patch_index, image_shape = image_shape, patch_size = patch_size)
loss = criterion(unmasked_patches, outputs)
loss.backward()
optimizer.step()
train_loss += loss.item()
train_cost = train_loss / len(trainloader)
tr_psnr = round(10 * math.log10(1.0 / train_cost), 3)
# ========================================================================
test_loss = 0.0
model.eval()
with torch.no_grad():
for original_images, masked_images, unmasked_patches, patch_index in testloader:
unmasked_patches = unmasked_patches.squeeze(2).to(device)
original_images = original_images.to(device)
image_shape = original_images.shape
outputs = model(unmasked_patches, SNRdB = params['SNR'][snr_i], channel = params['channel'], patch_index = patch_index, image_shape = image_shape, patch_size = patch_size)
loss = criterion(unmasked_patches, outputs)
test_loss += loss.item()
test_cost = test_loss / len(testloader)
test_psnr = round(10 * math.log10(1.0 / test_cost), 3)
scheduler.step(test_cost)
# Check the early-stopping condition
if test_cost < min_test_cost:
min_test_cost = test_cost
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve == n_epochs_stop:
print("Early stopping!")
break  # early stop
training_time = time.time() - timetemp
print(
"[{:>3}-Epoch({:>5}sec.)] PSNR(Train / Val) : {:>6.4f} / {:>6.4f}".format(
epoch + 1, round(training_time, 2), tr_psnr, test_psnr))
if test_psnr > max_psnr:
save_folder = 'trained_model'
if not os.path.exists(save_folder):
os.makedirs(save_folder)
previous_psnr = max_psnr
max_psnr = test_psnr
if previous_best_model_path is not None:
os.remove(previous_best_model_path)
print(f"Performance update!! {previous_psnr} to {max_psnr}")
save_path = os.path.join(save_folder, f"CBM(PS={patch_size}_DIM={latent_dim}_MR={mask_ratio}_IT={IT}_HM={HM}_NS={NS}_SNR={params['SNR'][snr_i]}_PSNR={max_psnr}).pt")
torch.save(model, save_path)
print(f"Saved new best model at {save_path}")
previous_best_model_path = save_path
if __name__ == '__main__':
for ps_i in range(len(params['PS'])):
for dim_i in range(len(params['DIM'])):
for mr_i in range(len(params['MR'])):
for hm_i in range(len(params['HM'])) :
Processed_train_path = "ProcessedTrain(PS=" + str(params['PS'][ps_i]) + "_MR=" + str(params['MR'][mr_i]) + "_IT=" + str(params['IT']) + "_HM=" + str(params['HM'][hm_i]) + ")"
Processed_test_path = "ProcessedTest(PS=" + str(params['PS'][ps_i]) + "_MR=" + str(params['MR'][mr_i]) + "_IT=" + str(params['IT']) + "_HM=" + str(params['HM'][hm_i]) + ")"
if not os.path.exists(Processed_train_path):
transform = transforms.Compose([transforms.ToTensor()])
train_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
preprocess_and_save_dataset(train_cifar10, Processed_train_path, patch_size=params['PS'][ps_i], mask_ratio=params['MR'][mr_i], importance_type=params['IT'], how_many=params['HM'][hm_i], noise_scale=params['NS'])
if not os.path.exists(Processed_test_path):
transform = transforms.Compose([transforms.ToTensor()])
test_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
preprocess_and_save_dataset(test_cifar10, Processed_test_path, patch_size=params['PS'][ps_i], mask_ratio=params['MR'][mr_i], importance_type=params['IT'], how_many=params['HM'][hm_i], noise_scale=params['NS'])
traindataset = PreprocessedCIFAR10Dataset(root_dir=Processed_train_path)
testdataset = PreprocessedCIFAR10Dataset(root_dir=Processed_test_path)
trainloader = DataLoader(traindataset, batch_size=params['BS'], shuffle=True, num_workers=4, drop_last = True)
testloader = DataLoader(testdataset, batch_size=params['BS'], shuffle=True, num_workers=4, drop_last = True)
train(params['DIM'][dim_i], params['PS'][ps_i], params['MR'][mr_i], trainloader, testloader, params['ES'], params['IT'], params['HM'][hm_i], params['NS'])
# params.py: hyperparameter settings imported at the top via "from params import *"
params = {
'BS': 64,
'LR': 0.001,
'EP': 5000,
'SNR': [40, 0],
'DIM': [512, 256],
'MR' : [0.33984],
'PS' : [2],
'ES' : 40,
'IT' : 'variance',
'HM' : [1, 2, 3, 4],
'NS' : 0,
'channel' : 'Rayleigh'
}