UOMOP

이거 성능 왜 높게 나옴? 본문

Wireless Comm./CISL

이거 성능 왜 높게 나옴?

Happy PinGu 2023. 8. 7. 17:24
import cv2
import math
import time
import torch
import random

import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset

def sobel_filter(img, x_or_y) :
    """Apply a first-order Sobel derivative to `img` along one axis.

    `x_or_y == 'x'` differentiates along x; any other value falls through
    to the y direction (matching the original permissive behavior).
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Sobel(img, -1, dx, dy)

def scharr_filter(img, x_or_y) :
    """Apply a Scharr edge filter to `img` along one axis.

    `x_or_y == 'x'` selects the x derivative; anything else selects y
    (same permissive dispatch as the Sobel helper).
    """
    dx, dy = (1, 0) if x_or_y == 'x' else (0, 1)
    return cv2.Scharr(img, -1, dx, dy)

def average_filter(img, kernel_size) :
    """Smooth `img` with a normalized box (mean) filter of the given size."""
    ksize = (kernel_size, kernel_size)
    return cv2.blur(img, ksize)

def gaussian_filter(img, kernel_size) :
    """Smooth `img` with a square Gaussian kernel.

    The sigma argument of cv2.GaussianBlur is the kernel's standard
    deviation in X/Y; passing 0 lets OpenCV derive it from the kernel size.
    """
    ksize = (kernel_size, kernel_size)
    return cv2.GaussianBlur(img, ksize, 0)

def median_filter(img, kernel_size) :
    """Apply a median filter with the given (odd) aperture size."""
    filtered = cv2.medianBlur(img, kernel_size)
    return filtered

def bilateral_filter(img, kernel_size) :
    """Edge-preserving bilateral filter (fixed sigmaColor/sigmaSpace = 75)."""
    sigma_color = sigma_space = 75
    return cv2.bilateralFilter(img, kernel_size, sigma_color, sigma_space)


# Select the first CUDA GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# Global experiment configuration, read throughout the script.
args = {
    'BATCH_SIZE' : 32,
    'LEARNING_RATE' : 0.001,
    'NUM_EPOCH' : 200,
    'SNRdB_list' : [1, 10, 20, 30],  # channel SNRs (dB) to train/evaluate at
    'pi_dim' : 512,                  # primary-information latent dimension
    'si_dim' : 0,                    # side-information latent dimension (0 disables the si branch)
    'input_dim' : 32 * 32,           # flattened high-pass image size (CIFAR-10 is 32x32)
    'filter_type' : 'scharr',        # high-pass filter for the si branch: 'sobel' or 'scharr'
    'filter_dir' : 'x',              # filter direction: 'x' or 'y'
    'es_pat' : 10,                   # EarlyStopping patience (early stopping is currently disabled)
    'rl_pat' : 10                    # ReduceLROnPlateau patience (epochs)
}


transf = tr.Compose([tr.ToTensor()])
# ToTensor converts CIFAR-10 images to float tensors in [0, 1].
# NOTE(review): the original (Korean) comment claimed the data is normalized
# to [-1, 1], but no Normalize transform is applied, so values stay in [0, 1].
# This matters for the peak value R passed to cv2.PSNR elsewhere — confirm
# the intended dynamic range.

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = True)

def spectrum(img) :
    """Return the log-magnitude spectrum of a 2-D image.

    Computes the 2-D FFT, centers the DC component with fftshift, and
    compresses the dynamic range with log(|F| + 1).
    """
    shifted = np.fft.fftshift(np.fft.fft2(img))
    return np.log(np.abs(shifted) + 1)


def psnr(SNRdB_list, model_name_without_SNRdB, testloader) :
    """Load one saved Autoencoder per SNR and return its average test PSNR.

    Parameters
    ----------
    SNRdB_list : list of channel SNRs (dB); one checkpoint is expected per SNR.
    model_name_without_SNRdB : checkpoint filename prefix (e.g. 'AWGN').
    testloader : DataLoader yielding (images, labels) batches in [0, 1].

    Returns
    -------
    list of average PSNR values (dB), rounded to 3 decimals, one per SNR.
    """
    psnr_list = []

    for SNRdB in SNRdB_list :

        model_name = model_name_without_SNRdB + "_pi=" + str(args['pi_dim']) +"_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) + ".pth"
        model_1 = Autoencoder().to(device)
        model_1.load_state_dict(torch.load(model_name))
        model_1.eval()  # disable training-mode behavior during evaluation

        psnr_sum = 0.0
        img_count = 0

        with torch.no_grad():  # no gradients needed for evaluation
            for data in testloader :

                inputs = data[0].to(device)
                outputs = model_1(inputs, SNRdB = SNRdB)

                # BUG FIX: the original looped `range(len(data))` — the length of
                # the [images, labels] pair (i.e. 2) — so only the first two
                # images of every batch were scored. Score the whole batch.
                for j in range(inputs.size(0)) :
                    # BUG FIX: the peak value passed to cv2.PSNR was 2, which
                    # assumes a [-1, 1] range; ToTensor yields [0, 1], so the
                    # correct peak is 1. R=2 inflated every PSNR by ~6 dB —
                    # the likely answer to "why is performance so high?".
                    psnr_sum += cv2.PSNR(inputs[j].cpu().numpy(), outputs[j].cpu().numpy(), 1)
                img_count += inputs.size(0)

        # Average over the true number of images (robust to a short last batch).
        psnr_list.append(round(psnr_sum / img_count, 3))

    return psnr_list


def model_train(SNRdB, learning_rate, epoch_num, trainloader) :
    """Train an Autoencoder over an AWGN channel at a fixed SNR and save it.

    Uses MSE reconstruction loss, Adam, and ReduceLROnPlateau on the
    validation PSNR (mode='max' because PSNR is maximized). Validation uses
    the module-level global `testloader`. The trained weights are saved to
    "./AWGN_pi=<pi>_si=<si>_SNR=<SNRdB>.pth".

    Parameters
    ----------
    SNRdB : channel SNR (dB) used for both training and validation noise.
    learning_rate : Adam learning rate.
    epoch_num : number of epochs.
    trainloader : DataLoader yielding (images, labels) batches in [0, 1].
    """
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=args['rl_pat'], threshold=1e-3, verbose = True)

    print("\n\n+++++ SNR = {} Training Start! +++++\t".format(SNRdB))

    for epoch in range(epoch_num) :

        train_psnr = 0.0
        val_psnr = 0.0
        train_count = 0
        val_count = 0

        ## Train dataset
        model.train()
        for data in trainloader :

            inputs = data[0].to(device)

            optimizer.zero_grad()
            outputs = model(inputs, SNRdB = SNRdB)

            loss = criterion(inputs, outputs)
            loss.backward()
            optimizer.step()

            # BUG FIX: `range(len(data))` iterated over the [images, labels]
            # pair (length 2), scoring only 2 images per batch. Score them all.
            for j in range(inputs.size(0)) :
                # BUG FIX: peak value was 2 (a [-1, 1] assumption); ToTensor
                # produces [0, 1], so the correct peak is 1.
                train_psnr += cv2.PSNR(inputs[j].detach().cpu().numpy(), outputs[j].detach().cpu().numpy(), 1)
            train_count += inputs.size(0)

        train_psnr_fin = round(train_psnr / train_count, 3)

        ## Validation on the module-level testloader — no gradients needed.
        model.eval()
        with torch.no_grad():
            for val_data in testloader :
                inputs = val_data[0].to(device)
                outputs = model(inputs, SNRdB = SNRdB)

                for j in range(inputs.size(0)) :
                    val_psnr += cv2.PSNR(inputs[j].cpu().numpy(), outputs[j].cpu().numpy(), 1)
                val_count += inputs.size(0)

        val_psnr_fin = round(val_psnr / val_count, 3)

        # BUG FIX: the scheduler was stepped with the raw PSNR *sum*
        # (val_psnr) instead of the per-image average it is meant to monitor.
        scheduler.step(val_psnr_fin)

        print("[{} Epoch] train_PSNR : {}\tvalidation_PSNR : {}".format(epoch + 1, train_psnr_fin, val_psnr_fin))

    torch.save(model.state_dict(), "./AWGN" + "_pi=" + str(args['pi_dim']) +"_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) +".pth")


class Encoder(nn.Module):
    """Two-branch encoder.

    The primary branch (encoder_pi) maps a 3x32x32 image to a pi_dim latent
    vector with strided convolutions. The side branch (encoder_si) maps a
    high-pass-filtered grayscale version of the image to an si_dim vector.

    NOTE(review): with args['si_dim'] == 0 the si branch emits an empty
    vector, so effectively only the pi path carries information.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        c_hid = 32

        # Primary path: 32x32 -> 16x16 -> 8x8 -> 4x4, then flatten to pi_dim.
        self.encoder_pi = nn.Sequential(
            nn.Conv2d(3, c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            nn.ReLU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            nn.ReLU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * c_hid, args['pi_dim'])
        )

        # Side path: flattened 32x32 high-pass image -> si_dim vector (MLP).
        self.encoder_si = nn.Sequential(
            nn.Linear(args['input_dim'], 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, args['si_dim'])
        )

    def HPF(self, x, filter_type, filter_dir):
        """High-pass-filter each image of the batch and return a CPU tensor.

        Each (C, H, W) tensor is converted to a grayscale numpy image and
        passed through the selected edge filter. On an unknown filter type
        or direction a diagnostic is printed and the image is skipped
        (matching the original behavior).

        NOTE(review): the 32x32 reshape below hard-codes CIFAR-10 input size.
        """
        save_list = []

        for i in range(x.shape[0]):
            # (C, H, W) tensor -> (H, W, C) numpy image -> single gray channel.
            gray_img = cv2.cvtColor(x[i].permute(1, 2, 0).detach().cpu().numpy(), cv2.COLOR_BGR2GRAY)

            if filter_type == 'sobel':
                if filter_dir in ('x', 'y'):
                    save_list.append(sobel_filter(gray_img, filter_dir))
                else:
                    print("Type Error(filter direction)")

            elif filter_type == 'scharr':
                if filter_dir in ('x', 'y'):
                    save_list.append(scharr_filter(gray_img, filter_dir))
                else:
                    print("Type Error(filter direction)")

            else:
                print("Type Error(filter type)")

        save_arr = np.array(save_list).reshape(x.shape[0], 1, 32, 32)
        return torch.Tensor(save_arr)

    def forward(self, x, filter_type, filter_dir):
        """Return (encoded_pi, encoded_si) for a batch of images."""
        # BUG FIX: the original ignored the filter_type/filter_dir parameters
        # and read args[...] directly. Honor the arguments instead — callers
        # in this file pass the same args values, so behavior is unchanged,
        # but the signature now actually means what it says.
        HPFed_x = self.HPF(x, filter_type, filter_dir).to(device)

        encoded_pi = self.encoder_pi(x)
        encoded_si = self.encoder_si(HPFed_x.view(-1, args['input_dim']).to(torch.float32))

        return encoded_pi, encoded_si


class EarlyStopping:
    """Tracks validation PSNR and flags when it stops improving.

    Each call compares the new PSNR against the best seen so far (plus
    `delta`). An improvement saves a checkpoint and resets the counter; a
    non-improvement increments it, and once the counter reaches `patience`
    the `early_stop` flag is raised for the training loop to act on.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.psnr_max = float(0)

    def __call__(self, psnr, model, SNRdB):
        # Guard clause: no improvement over best + delta -> count a strike.
        if self.best_score is not None and psnr < self.best_score + self.delta:
            self.counter += 1
            print("EarlyStopping counter : {} out of {}".format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First observation, or a genuine improvement: record and checkpoint.
        self.best_score = psnr
        self.save_checkpoint(psnr, model, SNRdB)
        self.counter = 0

    def save_checkpoint(self, psnr, model, SNRdB):
        """Save the model weights for the current best validation PSNR."""
        torch.save(model.state_dict(), "./AWGN" + "_pi=" + str(args['pi_dim']) +"_" + "si=" + str(args['si_dim']) + "_SNR=" + str(SNRdB) +".pth")
        self.psnr_max = psnr



class Decoder(nn.Module):
    """Maps received channel symbols back to a 3x32x32 image.

    A linear layer projects the (pi_dim + si_dim) vector to a 4x4 feature
    map, then three transposed-convolution stages upsample 4x4 -> 8x8 ->
    16x16 -> 32x32. Module attribute names and Sequential indices match the
    original so saved state_dicts remain loadable.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        base = 32

        # Project the channel vector to 2*base channels on a 4x4 grid.
        self.linear = nn.Sequential(
            nn.Linear(args['pi_dim'] + args['si_dim'], 2 * 16 * base),
            nn.ReLU()
        )

        up_layers = [
            # 4x4 => 8x8
            nn.ConvTranspose2d(2 * base, 2 * base, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(2 * base, 2 * base, kernel_size=3, padding=1),
            nn.ReLU(),
            # 8x8 => 16x16
            nn.ConvTranspose2d(2 * base, base, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(base, base, kernel_size=3, padding=1),
            nn.ReLU(),
            # 16x16 => 32x32, back to 3 color channels (no output activation)
            nn.ConvTranspose2d(base, 3, kernel_size=3, output_padding=1, padding=1, stride=2),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        hidden = self.linear(x)
        # Reshape the flat projection into a (N, 2*base, 4, 4) feature map.
        hidden = hidden.reshape(hidden.shape[0], -1, 4, 4)
        return self.decoder(hidden)


class Autoencoder(nn.Module):
    """End-to-end image autoencoder with a simulated AWGN channel between
    encoder and decoder (DeepJSCC-style): encode -> normalize + add noise
    -> decode.
    """
    def __init__(
            self,
            encoder_class: object = Encoder,
            decoder_class: object = Decoder
    ):
        super(Autoencoder, self).__init__()

        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def AWGN(self, input, SNRdB):
        """Simulate transmission over an AWGN channel at the given SNR (dB).

        The latent vector is L2-normalized per sample (unit total power),
        then zero-mean Gaussian noise is added.
        """
        normalized_tensor = f.normalize(input, dim=1)

        SNR = 10.0 ** (SNRdB / 10.0)
        K = args['pi_dim'] + args['si_dim']
        # Noise std chosen so total noise power over the K symbols is 1/SNR,
        # matching the unit signal power after normalization.
        std = 1 / math.sqrt(K * SNR)
        n = torch.normal(0, std, size=normalized_tensor.size()).to(device)

        return normalized_tensor + n


    def L2_Normalization(self, input):
        # NOTE(review): despite the name, this MULTIPLIES each row by its L2
        # norm rather than dividing, and returns a CPU tensor detached from
        # the graph. It appears unused in this file — confirm intent before
        # relying on it.
        norm_2 = torch.norm(input, p=2, dim=1)

        out_list = []

        for i in range(len(input)):
            out_list.append((norm_2[i] * input[i]).tolist())

        return torch.Tensor(out_list)

    def forward(self, x, SNRdB):
        """Encode `x`, pass the concatenated latents through the AWGN
        channel at `SNRdB`, and decode back to image space."""
        encoded_pi, encoded_si = self.encoder(x, args['filter_type'], args['filter_dir'])

        # Concatenate primary and side information into one transmit vector.
        Tx_output = torch.cat([encoded_pi, encoded_si], dim=1)

        # Channel: normalize + additive white Gaussian noise.
        Rx_input = self.AWGN(Tx_output, SNRdB)

        decoded = self.decoder(Rx_input)

        return decoded

# --- Driver: train one model per SNR, then plot test PSNR vs SNR. ---
SNRdB_list = args['SNRdB_list']
learning_rate = args['LEARNING_RATE']
epoch_num = args['NUM_EPOCH']
model_name_without_SNRdB = 'AWGN'

# Iterate the SNR values directly (the original indexed with range(len(...))
# and contained a no-op `trainloader = trainloader` self-assignment).
for snr_db in SNRdB_list:
    model_train(snr_db, learning_rate, epoch_num, trainloader)

import os
# Workaround for the "duplicate OpenMP runtime" abort that some
# torch + matplotlib combinations trigger when plotting.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

psnr_list = psnr(SNRdB_list, model_name_without_SNRdB, testloader)
plt.plot(SNRdB_list, psnr_list, linestyle = 'dashed', color = 'blue', label = "AWGN")
plt.grid(True)
plt.legend()
plt.ylim([10, 40])
plt.show()

print(psnr_list)

'Wireless Comm. > CISL' 카테고리의 다른 글

Custom Dataset(Gaussian Filtered Cifar 10)  (0) 2023.08.04
AWGN_cifar10_20230804  (0) 2023.08.04
DeepJSCC  (0) 2023.08.02
AWGN(pi=1024 si=0)  (0) 2023.07.28
AWGN(pi=64, si=0)  (0) 2023.07.27
Comments