Colorization CIFAR-10 (gray to color)

Research/Semantic Communication · Happy PinGu · 2023. 12. 19. 17:38
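This post collects evaluation code for a conditional GAN that colorizes grayscale CIFAR-10 images: a U-Net generator maps a 1-channel grayscale image to a 3-channel color image, a conditional discriminator scores (color, grayscale) pairs, and the generated images are displayed alongside the reconstructions of a DeepJSCC autoencoder evaluated at 30 dB SNR.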
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from mpl_toolkits.axes_grid1 import ImageGrid
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10

def cifar10(batch_sz, random_seed, valid_size=0.2, shuffle=True):

  if random_seed is not None:
    torch.manual_seed(random_seed)

  # All three splits currently use ToTensor() only; train-time augmentation
  # would go into transform_train.
  transform_train = transforms.Compose([
      transforms.ToTensor(),
  ])
  transform_valid = transforms.ToTensor()
  transform_test = transforms.ToTensor()

  # Both datasets read the same train split; the disjoint samplers below keep
  # the training and validation images separate.
  train_dataset = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
  valid_dataset = CIFAR10(root='./datasets', train=True, download=True, transform=transform_valid)

  num_train = len(train_dataset)
  indices = list(range(num_train))
  split = int(np.floor(valid_size * num_train))
  if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
  train_idx, valid_idx = indices[split:], indices[:split]

  train_sampler = SubsetRandomSampler(train_idx)
  valid_sampler = SubsetRandomSampler(valid_idx)
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_sz, sampler=train_sampler, pin_memory=True)
  valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_sz, sampler=valid_sampler, pin_memory=True)

  test_dataset = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)
  test_loader = DataLoader(test_dataset, batch_size=batch_sz)

  return train_loader, valid_loader, test_loader
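With valid_size=0.2, the 50,000 CIFAR-10 training images are split 40,000/10,000 between the train and validation samplers. A quick check of the split sizes (a sketch, not part of the original code):

tl, vl, _ = cifar10(batch_sz=64, random_seed=0)
print(len(tl.sampler), len(vl.sampler))  # 40000 10000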


def convert_to_grayscale(batch):
  """Convert a batch of RGB images to grayscale by averaging the channels."""
  grayscale_dataset = []
  for image in batch:
    # Channel mean: [3, H, W] -> [1, H, W]
    grayscale_image = torch.mean(image, dim=0, keepdim=True)
    grayscale_dataset.append(grayscale_image)

  grayscale_dataset = torch.stack(grayscale_dataset, dim=0)

  return grayscale_dataset
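convert_to_grayscale averages the three RGB channels, so a [B, 3, 32, 32] batch becomes a single-channel [B, 1, 32, 32] batch. A minimal sanity check (sketch):

batch = torch.rand(8, 3, 32, 32)  # stand-in for a CIFAR-10 batch in [0, 1]
gray = convert_to_grayscale(batch)
print(gray.shape)  # torch.Size([8, 1, 32, 32])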

batch_sz = 64  # number of images per batch
random_seed=2000
train_loader, valid_loader, test_loader = cifar10(batch_sz, random_seed)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class Generator(nn.Module):
  """Generator model"""
  def __init__(self):
    super(Generator, self).__init__()

    # U-Net: the encoder halves the spatial size with max-pooling after each
    # unit, and the decoder upsamples with transposed convolutions and
    # concatenates the matching encoder features (skip connections).

    self.encoding_unit1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(),
    )

    self.encoding_unit2 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(),
    )

    self.encoding_unit3 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(),
    )

    self.encoding_unit4 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(),
    )

    self.encoding_unit5 = nn.Sequential(
        nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(),
    )

    self.decoding_unit1 = nn.Sequential(
        nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(),
    )

    self.decoding_unit2 = nn.Sequential(
        nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(),
    )

    self.decoding_unit3 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(),
    )

    self.decoding_unit4 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(),
    )
    self.decoding_unit5 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0),
    )

    self.upsampling1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.upsampling2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.upsampling3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
    self.upsampling4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)

  def forward(self, x):

    # Encoder path; keep each pre-pooling activation for the skip connections.
    x1 = self.encoding_unit1(x)
    x = F.max_pool2d(x1, (2,2))

    x2 = self.encoding_unit2(x)
    x = F.max_pool2d(x2, (2,2))

    x3 = self.encoding_unit3(x)
    x = F.max_pool2d(x3, (2,2))

    x4 = self.encoding_unit4(x)
    x = F.max_pool2d(x4, (2,2))

    x = self.encoding_unit5(x)

    x = self.upsampling1(x)
    x = torch.cat([x, x4], dim=1)
    x = self.decoding_unit1(x)

    x = self.upsampling2(x)
    x = torch.cat([x, x3], dim=1)
    x = self.decoding_unit2(x)

    x = self.upsampling3(x)
    x = torch.cat([x, x2], dim=1)
    x = self.decoding_unit3(x)

    x = self.upsampling4(x)
    x = torch.cat([x, x1], dim=1)
    x = self.decoding_unit4(x)

    x = self.decoding_unit5(x)

    x = torch.tanh(x)  # map outputs to [-1, 1]

    return x
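The generator maps the 1-channel grayscale condition to a 3-channel color image at the same 32x32 resolution, and the final tanh keeps outputs in [-1, 1]. A shape check (sketch; eval mode so BatchNorm uses running statistics):

g = Generator().eval()
with torch.no_grad():
    out = g(torch.rand(2, 1, 32, 32))
print(out.shape)  # torch.Size([2, 3, 32, 32])
print(out.min().item() >= -1.0, out.max().item() <= 1.0)  # True True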

class Discriminator(nn.Module):
  """Discriminator model"""
  def __init__(self):
    super(Discriminator, self).__init__()
    self.disc_model = nn.Sequential(
        # 4 input channels: 3 (RGB candidate) + 1 (grayscale condition)
        nn.Conv2d(in_channels=4, out_channels=64, kernel_size=5, stride=2, padding=2),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(),

        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(),

        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(),

        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(),


        nn.Flatten(1,-1),

    )

    self.layer_1 = nn.Linear(512 * 2 * 2, 256)
    self.layer_1_b = nn.BatchNorm1d(256)
    self.layer_1_a = nn.LeakyReLU()

    self.layer_2 = nn.Linear(256, 128)
    self.layer_2_b = nn.BatchNorm1d(128)
    self.layer_2_a = nn.LeakyReLU()

    self.layer_3 = nn.Linear(128, 1)
    self.layer_3_a = nn.Sigmoid()

  def forward(self, x, condition):

    # Concatenate the color image with its grayscale condition along channels.
    x = torch.cat([x, condition], 1)

    x = self.disc_model(x)

    x = self.layer_1_a(self.layer_1_b(self.layer_1(x)))
    x = self.layer_2_a(self.layer_2_b(self.layer_2(x)))
    x = self.layer_3_a(self.layer_3(x))

    return x
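The discriminator is conditional: the candidate color image and its grayscale condition are concatenated into a 4-channel input, and the sigmoid head returns one real/fake probability per image. A usage sketch:

d = Discriminator().eval()
rgb = torch.rand(4, 3, 32, 32)   # candidate colorization
cond = torch.rand(4, 1, 32, 32)  # grayscale condition
with torch.no_grad():
    score = d(rgb, cond)
print(score.shape)  # torch.Size([4, 1]), each value in (0, 1)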

def display_image_grid(images, labels, num_rows, num_cols, title_text):
  fig = plt.figure(figsize=(num_cols*3., num_rows*3.))
  grid = ImageGrid(fig, 111, nrows_ncols=(num_rows, num_cols), axes_pad=0.15)

  for ax, im, l in zip(grid, images, labels):
    arr = im.cpu().permute(1, 2, 0).numpy()
    # Float images are assumed to lie in [0, 1], integer images in [0, 255].
    max_val = 1 if im.dtype in (torch.float32, torch.float64) else 255
    arr = np.clip(arr, 0, max_val)
    if im.size(0) == 1:
      # imshow expects (H, W) for grayscale, so drop the channel axis.
      ax.imshow(arr.squeeze(-1), cmap='gray')
    else:
      ax.imshow(arr)

    ax.axis("off")
    ax.set_title(classes[l])

  plt.suptitle(title_text, fontsize=20)
  plt.tight_layout()
  plt.show()

args = {
    'BATCH_SIZE' : 64,
    'LEARNING_RATE' : 0.0005,
    'NUM_EPOCH' : 500,
    'SNRdB_list' : [30],
    'latent_dim' : 512,
    'input_size' : 32*32
}

transf = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transf)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transf)

trainloader = DataLoader(trainset, batch_size=args['BATCH_SIZE'], shuffle=True)
testloader = DataLoader(testset, batch_size=args['BATCH_SIZE'], shuffle=True)

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        encoded = self.encoder(x)

        return encoded

class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32

        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            nn.ReLU(),

            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            nn.ReLU(),

            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)

        return decoded
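Encoder and Decoder mirror each other: three stride-2 stages take 3x32x32 down to 64 channels at 4x4, which are flattened into a 512-dimensional latent, and the decoder inverts the path. A round-trip shape check (sketch):

enc, dec = Encoder(), Decoder()
with torch.no_grad():
    z = enc(torch.rand(2, 3, 32, 32))
    x_hat = dec(z)
print(z.shape, x_hat.shape)  # torch.Size([2, 512]) torch.Size([2, 3, 32, 32])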

class Autoencoder(nn.Module):
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()


    def AWGN(self, input, SNRdB):
        # Power-normalize the latent so each K-dimensional vector has unit L2
        # norm, then add Gaussian noise scaled to the requested SNR.
        normalized_tensor = F.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)  # per-symbol signal power is 1/K
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)

        return normalized_tensor + n



    def forward(self, x, SNRdB):

        # Encoder -> simulated AWGN channel -> decoder (DeepJSCC pipeline).
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)

        return decoded
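Why std = 1/sqrt(K*SNR): F.normalize gives each K-dimensional latent unit L2 norm, so the average power per transmitted symbol is 1/K; dividing by the noise variance 1/(K*SNR) recovers exactly the requested SNR. A numeric check (sketch):

K, SNRdB = 512, 30
SNR = 10.0 ** (SNRdB / 10.0)
std = 1 / math.sqrt(K * SNR)
print(10 * math.log10((1.0 / K) / std ** 2))  # 30.0, the target SNRdB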



device = "cuda"
def test_GAN(generator, discriminator, l1_reg):

    real_label = 1
    fake_label = 0

    # DeepJSCC baseline (latent dim 512, trained at 30 dB) for visual comparison.
    model_location = "DeepJSCC(512)_30dB.pth"
    model = Autoencoder().to(device)
    model.load_state_dict(torch.load(model_location, map_location=device))
    model.eval()

    generator.eval()
    discriminator.eval()
    loss_fn = nn.BCELoss()

    n_batch = 0
    num_batches = 0
    total_gen_loss = 0.0
    total_disc_loss = 0.0

    with torch.no_grad():
      for batch in test_loader:
        original_images, targets = batch  # `targets` avoids shadowing the global `classes` tuple
        inputs = original_images.to(device)
        grayscale_images = convert_to_grayscale(original_images)
        condition = grayscale_images.to(device)

        # DeepJSCC operates on the unscaled [0, 1] inputs.
        outputs_for_deepjscc = model(inputs, SNRdB=30)

        # Rescale to [-1, 1] to match the generator's tanh output range.
        inputs = (inputs - 0.5) * 2
        condition = (condition - 0.5) * 2

        label_real = torch.full((inputs.shape[0], 1), real_label, dtype=torch.float, device=device)
        label_fake = torch.full((inputs.shape[0], 1), fake_label, dtype=torch.float, device=device)

        fake = generator(condition)
        output_real = discriminator(inputs, condition)
        output_fake = discriminator(fake, condition)

        # Discriminator loss: real pairs scored as real, generated pairs as fake.
        errD = loss_fn(output_real, label_real) + loss_fn(output_fake, label_fake)

        # Generator loss: fool the discriminator, plus an L1 penalty toward the ground truth.
        errG = loss_fn(output_fake, label_real) + (l1_reg * torch.mean(torch.abs(fake - inputs)))

        total_disc_loss += errD.item()
        total_gen_loss += errG.item()
        num_batches += 1

        # Visualize the first few batches.
        if n_batch <= 10:
          print("Batch", n_batch + 1)
          labels = targets.numpy()
          # Map tanh output from [-1, 1] to displayable uint8 pixels.
          generated_batch = (fake.cpu() + 1) / 2
          generated_batch = (generated_batch * 255).clamp(0, 255).to(torch.uint8)

          display_image_grid(original_images, labels, 1, 10, "Original Images")
          display_image_grid(grayscale_images, labels, 1, 10, "Grayscale Images")
          display_image_grid(generated_batch, labels, 1, 10, "Generated Images")
          display_image_grid(outputs_for_deepjscc, labels, 1, 10, "DeepJSCC(30dB)")

          n_batch += 1

    # BCELoss averages within each batch, so average over batches.
    gen_test_loss = total_gen_loss / num_batches
    disc_test_loss = total_disc_loss / num_batches

    print("Generator test loss:", gen_test_loss)
    print("Discriminator test loss:", disc_test_loss)


generator_location = "Generator.pth"
generator = Generator().to(device)
generator.load_state_dict(torch.load(generator_location, map_location=device))

discriminator_location = "Discriminator.pth"
discriminator = Discriminator().to(device)
discriminator.load_state_dict(torch.load(discriminator_location, map_location=device))

# l1_reg = 80 weights the L1 term in the generator loss.
test_GAN(generator, discriminator, 80)

 

'Research > Semantic Communication' 카테고리의 다른 글

JPEG의 header, tail  (0) 2023.12.06
temp_20231205  (0) 2023.12.05
JPEG temp  (0) 2023.12.05
JPEG+LDPC+16QAM python  (0) 2023.12.05
temp : JPEG+LDPG+16QAM  (0) 2023.12.04
Comments