


iNet_4MLP_0.2DO

Happy PinGu 2023. 11. 13. 14:01
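This post's code trains a DeepJSCC-style convolutional autoencoder on CIFAR-10 with a simulated AWGN channel between encoder and decoder. The 512-dimensional latent is passed through a shared four-layer MLP with 0.2 dropout a random one to four times per batch before transmission, and a separate model is trained for each channel SNR in {0, 15, 30} dB. Validation PSNR is reported per MLP iteration count.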
import math
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as tr
from torch.utils.data import DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

args = {
    'BATCH_SIZE': 64,
    'LEARNING_RATE': 0.0001,
    'NUM_EPOCH': 1000,
    'SNRdB_list': [0, 15, 30],  # channel SNRs in dB; one model is trained per value
    'latent_dim': 512,          # number of real-valued channel symbols per image
    'input_dim': 32 * 32
}

transf = tr.Compose([tr.ToTensor()])

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = False)  # no need to shuffle the evaluation set
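
# Quick sanity check (added sketch, not in the original post): a batch from the
# loader is a [64, 3, 32, 32] tensor of RGB pixels scaled to [0, 1] by ToTensor.
sample_images, sample_labels = next(iter(trainloader))
print(sample_images.shape)  # torch.Size([64, 3, 32, 32])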


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        c_hid = 32

        self.encoder = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['latent_dim'])
        )

    def forward(self, x):
        encoded = self.encoder(x)

        return encoded

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        # Four stacked Linear layers with 0.2 dropout (the "4MLP_0.2DO" of the
        # post title). Note there is no activation between the layers, so at
        # inference time the stack collapses to a single affine map.
        self.mlp = nn.Sequential(
            nn.Linear(512, 700),
            nn.Linear(700, 600),
            nn.Linear(600, 512),
            nn.Linear(512, 512),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        output = self.mlp(x)

        return output

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c_hid = 32  # This is the hidden layer size, you can adjust it if you want a wider network

        # We increase the network depth by adding more layers
        self.linear = nn.Sequential(
            nn.Linear(args['latent_dim'], 2 * 16 * c_hid),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.BatchNorm2d(2 * c_hid),
            nn.ReLU(),

            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.BatchNorm2d(2 * c_hid),
            nn.ReLU(),

            # We add extra layers here
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.BatchNorm2d(2 * c_hid),
            nn.ReLU(),

            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.BatchNorm2d(2 * c_hid),
            nn.ReLU(),

            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.BatchNorm2d(c_hid),
            nn.ReLU(),

            # And more layers here
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_hid),
            nn.ReLU(),

            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_hid),
            nn.ReLU(),

            nn.ConvTranspose2d(c_hid, 3, kernel_size=3, output_padding=1, padding=1, stride=2),
            # Here we don't use BatchNorm or ReLU since this is the output layer
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        decoded = self.decoder(x)

        return decoded
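
# Shape walkthrough (added sketch, not in the original post):
#   Encoder: (B, 3, 32, 32) -> three stride-2 convs -> (B, 64, 4, 4) -> flatten/linear -> (B, 512)
#   Decoder: (B, 512) -> linear -> (B, 1024) -> reshape (B, 64, 4, 4)
#            -> three stride-2 transposed convs -> (B, 3, 32, 32)
# Quick compatibility check between the two halves:
_enc, _dec = Encoder().to(device), Decoder().to(device)
_z = _enc(torch.rand(2, 3, 32, 32, device=device))
print(_z.shape, _dec(_z).shape)  # torch.Size([2, 512]) torch.Size([2, 3, 32, 32])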

class Autoencoder(nn.Module):
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder,
            mlp_class    : object = MLP
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()
        self.mlp     = mlp_class()


    def AWGN(self, x, SNRdB):
        # Normalize each latent vector to unit L2 norm, so the average signal
        # power per dimension is 1/K.
        normalized_tensor = F.normalize(x, dim=1)

        # Noise std = 1/sqrt(K * SNR) gives per-dimension noise power
        # 1/(K * SNR), so signal power / noise power = SNR as requested.
        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['latent_dim']
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)

        return normalized_tensor + n


    def forward(self, x, SNRdB, iteration):

        encoded = self.encoder(x)

        # Pass the latent through the shared MLP `iteration` times in total.
        mlped = self.mlp(encoded)
        for _ in range(iteration - 1):
            mlped = self.mlp(mlped)

        channel_output = self.AWGN(mlped, SNRdB)
        decoded = self.decoder(channel_output)

        return decoded
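
# Usage sketch (added, not in the original post): one noisy forward pass. The
# latent goes through the shared MLP twice here, is power-normalized, corrupted
# by the simulated AWGN channel at 15 dB, and decoded back to an image batch.
_model = Autoencoder().to(device)
_x_hat = _model(torch.rand(2, 3, 32, 32, device=device), SNRdB=15, iteration=2)
print(_x_hat.shape)  # torch.Size([2, 3, 32, 32])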


def model_train(trainloader, testloader):
    for i in range(len(args['SNRdB_list'])):

        model = Autoencoder().to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=args['LEARNING_RATE'])

        print("+++++ SNR = {} Training Start! +++++\t".format(args['SNRdB_list'][i]))

        for epoch in range(args['NUM_EPOCH']):

            print("[{} Epoch]\n".format(epoch + 1))

            # ========================================== Train ==========================================
            train_loss = 0.0

            model.train()

            for data in trainloader:
                inputs = data[0].to(device)
                optimizer.zero_grad()

                iteration = random.randint(1, 4)
                outputs = model(inputs, SNRdB=args['SNRdB_list'][i], iteration=iteration)
                loss = criterion(inputs, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_cost = train_loss / len(trainloader)
            # PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1.0, since ToTensor scales pixels to [0, 1].
            print("Train PSNR : {}".format(round(10 * math.log10(1.0 / train_cost), 4)))
            # ========================================== Test ==========================================
            # Summed MSE and batch counts, bucketed by iteration count (index = iteration - 1).
            mse_sum = [0.0, 0.0, 0.0, 0.0]
            iter_counter = [0, 0, 0, 0]

            model.eval()

            for data in testloader:
                inputs = data[0].to(device)
                with torch.no_grad():
                    iteration = random.randint(1, 4)
                    outputs = model(inputs, SNRdB=args['SNRdB_list'][i], iteration=iteration)
                    # Accumulate this batch's MSE in the bucket for its iteration count.
                    mse_sum[iteration - 1] += criterion(inputs, outputs).item()
                    iter_counter[iteration - 1] += 1

            # Average MSE per iteration count; assumes every count in 1-4 was
            # sampled at least once during the epoch (with ~157 test batches
            # this is essentially always the case).
            mse_per_iter = [mse_sum[k] / iter_counter[k] for k in range(4)]
            mse_average = sum(mse_per_iter) / 4

            for k in range(4):
                print("Val PSNR(iter={})  : {}(counter = {}/{})".format(
                    k + 1, round(10 * math.log10(1.0 / mse_per_iter[k]), 4),
                    iter_counter[k], sum(iter_counter)))
            # Note: this is the PSNR of the average MSE, not the average of the
            # per-iteration PSNRs.
            print("Val PSNR(average) : {}".format(round(10 * math.log10(1.0 / mse_average), 4)))
            print("\n\n")

        torch.save(model.state_dict(), "DeepJSCC_" + str(args['SNRdB_list'][i]) + "dB.pth")

model_train(trainloader, testloader)
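
# Reloading a trained model later (a sketch; the file names follow the
# torch.save call above):
# model = Autoencoder().to(device)
# model.load_state_dict(torch.load("DeepJSCC_15dB.pth", map_location=device))
# model.eval()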
