Colorization CIFAR-10 (gray to color)

A conditional GAN with a U-Net generator colorizes grayscale CIFAR-10 images; for comparison, the same test batches are also run through a pretrained DeepJSCC autoencoder operating over a 30 dB AWGN channel.
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from mpl_toolkits.axes_grid1 import ImageGrid
def cifar10(batch_sz, random_seed, valid_size=0.2, shuffle=True):
    if random_seed is not None:
        torch.manual_seed(random_seed)
    transform_train = transforms.ToTensor()
    transform_valid = transforms.ToTensor()
    transform_test = transforms.ToTensor()
    train_dataset = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
    valid_dataset = CIFAR10(root='./datasets', train=True, download=True, transform=transform_valid)
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = DataLoader(train_dataset, batch_size=batch_sz, sampler=train_sampler, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_sz, sampler=valid_sampler, pin_memory=True)
    test_dataset = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_sz)
    return train_loader, valid_loader, test_loader
def convert_to_grayscale(batch):
    grayscale_dataset = []
    for image in batch:
        # Convert the color image to grayscale by averaging the three channels
        grayscale_image = torch.mean(image, dim=0, keepdim=True)
        grayscale_dataset.append(grayscale_image)
    grayscale_dataset = torch.stack(grayscale_dataset, dim=0)
    return grayscale_dataset
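# Note: averaging the channels weights R, G and B equally. A common alternative
# (sketched here as a suggestion, not used anywhere in this post) is the
# ITU-R BT.601 luma weighting, which matches perceived brightness more closely:
def convert_to_grayscale_luma(batch):
    weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
    return (batch * weights).sum(dim=1, keepdim=True)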
batch_sz = 64  # number of images per mini-batch
random_seed = 2000
train_loader, valid_loader, test_loader = cifar10(batch_sz, random_seed)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class Generator(nn.Module):
    """U-Net generator: colorizes a 1-channel grayscale image into 3 channels."""
    def __init__(self):
        super(Generator, self).__init__()
        # U-Net encoder blocks
        self.encoding_unit1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
        )
        self.encoding_unit2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
        )
        self.encoding_unit3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
        )
        self.encoding_unit4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
        )
        self.encoding_unit5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
        )
        # Decoder blocks; each consumes the upsampled features concatenated with
        # the matching encoder output, hence the doubled input channel counts
        self.decoding_unit1 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
        )
        self.decoding_unit2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
        )
        self.decoding_unit3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
        )
        self.decoding_unit4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
        )
        self.decoding_unit5 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0),
        )
        # Transposed convolutions double the spatial resolution at each decoder stage
        self.upsampling1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upsampling2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upsampling3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.upsampling4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        # Encoder path: keep each pre-pooling activation for the skip connections
        x1 = self.encoding_unit1(x)
        x = F.max_pool2d(x1, (2, 2))
        x2 = self.encoding_unit2(x)
        x = F.max_pool2d(x2, (2, 2))
        x3 = self.encoding_unit3(x)
        x = F.max_pool2d(x3, (2, 2))
        x4 = self.encoding_unit4(x)
        x = F.max_pool2d(x4, (2, 2))
        x = self.encoding_unit5(x)
        # Decoder path: upsample, concatenate the skip connection, then refine
        x = self.upsampling1(x)
        x = torch.cat([x, x4], dim=1)
        x = self.decoding_unit1(x)
        x = self.upsampling2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.decoding_unit2(x)
        x = self.upsampling3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.decoding_unit3(x)
        x = self.upsampling4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.decoding_unit4(x)
        x = self.decoding_unit5(x)
        x = torch.tanh(x)  # outputs in [-1, 1], matching the (x - 0.5) * 2 input scaling
        return x
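# Quick sanity check (my addition, not part of the original pipeline): a 1-channel
# 32x32 grayscale batch should come back as 3-channel 32x32 images bounded by tanh.
_g = Generator()
_out = _g(torch.randn(2, 1, 32, 32))
print(_out.shape, _out.min().item(), _out.max().item())  # torch.Size([2, 3, 32, 32]), values in (-1, 1)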
class Discriminator(nn.Module):
    """Conditional discriminator: scores a color image given its grayscale condition."""
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input is the 3-channel image concatenated with the 1-channel condition
        self.disc_model = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Flatten(1, -1),
        )
        self.layer_1 = nn.Linear(512 * 2 * 2, 256)
        self.layer_1_b = nn.BatchNorm1d(256)
        self.layer_1_a = nn.LeakyReLU()
        self.layer_2 = nn.Linear(256, 128)
        self.layer_2_b = nn.BatchNorm1d(128)
        self.layer_2_a = nn.LeakyReLU()
        self.layer_3 = nn.Linear(128, 1)
        self.layer_3_a = nn.Sigmoid()

    def forward(self, x, condition):
        x = torch.cat([x, condition], 1)
        x = self.disc_model(x)
        x = self.layer_1_a(self.layer_1_b(self.layer_1(x)))
        x = self.layer_2_a(self.layer_2_b(self.layer_2(x)))
        x = self.layer_3_a(self.layer_3(x))
        return x
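# Sanity check (my addition): the discriminator takes the candidate color image
# plus its grayscale condition, i.e. 3 + 1 = 4 input channels, and returns one
# sigmoid score per sample.
_d = Discriminator()
_score = _d(torch.randn(2, 3, 32, 32), torch.randn(2, 1, 32, 32))
print(_score.shape)  # torch.Size([2, 1]), values in (0, 1)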
def display_image_grid(images, labels, num_rows, num_cols, title_text):
    fig = plt.figure(figsize=(num_cols * 3., num_rows * 3.))
    grid = ImageGrid(fig, 111, nrows_ncols=(num_rows, num_cols), axes_pad=0.15)
    for ax, im, l in zip(grid, images, labels):
        if im.size(0) == 1:
            if im.dtype == torch.float32 or im.dtype == torch.float64:
                ax.imshow(np.clip(im.cpu().permute(1, 2, 0).numpy(), 0, 1), cmap='gray')
            else:
                ax.imshow(np.clip(im.cpu().permute(1, 2, 0).numpy(), 0, 255), cmap='gray')
        else:
            if im.dtype == torch.float32 or im.dtype == torch.float64:
                ax.imshow(np.clip(im.cpu().permute(1, 2, 0).numpy(), 0, 1))
            else:
                ax.imshow(np.clip(im.cpu().permute(1, 2, 0).numpy(), 0, 255))
        ax.axis("off")
        ax.set_title(classes[l])
    plt.suptitle(title_text, fontsize=20)
    plt.tight_layout()
    plt.show()
args = {
    'BATCH_SIZE': 64,
    'LEARNING_RATE': 0.0005,
    'NUM_EPOCH': 500,
    'SNRdB_list': [30],
    'latent_dim': 512,
    'input_size': 32 * 32
}
transf = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)
trainloader = DataLoader(trainset, batch_size=args['BATCH_SIZE'], shuffle=True)
testloader = DataLoader(testset, batch_size=args['BATCH_SIZE'], shuffle=True)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32
        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # image grid to a single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return encoded
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32
        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)
        return decoded
class Autoencoder(nn.Module):
    def __init__(
        self,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        # Normalize the latent to unit power, then add Gaussian noise whose
        # variance is chosen so the channel operates at the requested SNR
        normalized_tensor = F.normalize(input, dim=1)
        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)
        return normalized_tensor + n

    def forward(self, x, SNRdB):
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)
        return decoded
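# Sanity check (my addition): F.normalize gives a unit-norm latent, so total
# signal power is 1; with noise variance 1/(K*SNR) per dimension the total
# noise power is 1/SNR, and the empirical SNR should match SNRdB.
_K, _SNRdB = args['latent_dim'], 30
_z = F.normalize(torch.randn(1000, _K), dim=1)
_n = torch.normal(0, 1 / math.sqrt(_K * 10 ** (_SNRdB / 10)), size=_z.shape)
print(10 * torch.log10((_z.pow(2).sum(dim=1) / _n.pow(2).sum(dim=1)).mean()))  # ~30 dB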
device = "cuda" if torch.cuda.is_available() else "cpu"
def test_GAN(generator, discriminator, l1_reg):
    real_label = 1.0
    fake_label = 0.0
    model_location = "DeepJSCC(512)_30dB.pth"
    model = Autoencoder().to(device)
    model.load_state_dict(torch.load(model_location))
    generator.eval()
    discriminator.eval()
    loss_fn = nn.BCELoss()
    n_batch = 0
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            original_images, targets = batch
            inputs = original_images.to(device)
            grayscale_images = convert_to_grayscale(original_images)
            condition = grayscale_images.to(device)
            outputs_for_deepjscc = model(inputs, SNRdB=30)
            # Rescale from [0, 1] to [-1, 1] to match the generator's tanh output
            inputs = (inputs - 0.5) * 2
            condition = (condition - 0.5) * 2
            real_labels = torch.full((inputs.shape[0], 1), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((inputs.shape[0], 1), fake_label, dtype=torch.float, device=device)
            output_real = discriminator(inputs, condition)
            fake = generator(condition)
            output_fake = discriminator(fake, condition)
            # Discriminator loss: real images scored as real, generated ones as fake
            errD = loss_fn(output_real, real_labels) + loss_fn(output_fake, fake_labels)
            # Generator loss: fool the discriminator, plus an L1 penalty toward the ground truth
            errG = loss_fn(output_fake, real_labels) + l1_reg * torch.mean(torch.abs(fake - inputs))
            total_disc_loss += errD.item()
            total_gen_loss += errG.item()
            if n_batch <= 10:
                print("Batch", n_batch + 1)
                labels = targets.numpy()
                # Map the generated images back to [0, 255] for display
                generated_batch = ((fake.cpu() + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
                display_image_grid(original_images, labels, 1, 10, "Original Images")
                display_image_grid(grayscale_images, labels, 1, 10, "Grayscale Images")
                display_image_grid(generated_batch, labels, 1, 10, "Generated Images")
                display_image_grid(outputs_for_deepjscc, labels, 1, 10, "DeepJSCC(30dB)")
            n_batch += 1
    # Each .item() is already a batch mean, so average over batches
    gen_test_loss = total_gen_loss / n_batch
    disc_test_loss = total_disc_loss / n_batch
    print("Generator test loss:", gen_test_loss)
    print("Discriminator test loss:", disc_test_loss)
generator_location = "Generator.pth"
generator = Generator().to(device)
generator.load_state_dict(torch.load(generator_location))
discriminator_location = "Discriminator.pth"
discriminator = Discriminator().to(device)
discriminator.load_state_dict(torch.load(discriminator_location))
test_GAN(generator, discriminator, 80)
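# Optional follow-up (my addition, not in the original post): a simple PSNR helper
# to compare the GAN colorization and the DeepJSCC reconstruction quantitatively,
# assuming both tensors are scaled to [0, 1].
def psnr(x, y):
    mse = torch.mean((x - y) ** 2)
    return 10 * torch.log10(1.0 / mse)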