UOMOP

send img with contour : 20230603 본문

Wireless Comm./Python

send img with contour : 20230603

Happy PinGu 2023. 6. 3. 01:56
import cv2
import math
import time
import torch
import random

import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
from tqdm import trange
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

# Experiment hyper-parameters / configuration (read by every class below).
args = {
    'BATCH_SIZE' : 50,
    'LEARNING_RATE' : 0.001,
    'NUM_EPOCH' : 20,
    'SNR_dB' : [1, 10, 20],                    # channel SNRs (in dB) to sweep
    'latent_dim' : 100,                        # total latent size (both encoders combined)
    'input_dim' : 32*32,                       # flattened 32x32 grayscale image
    'comp_ratio' : 0.7,                        # latent fraction given to Encoder1 (rest to Encoder2)
    'filter_size' : [1, 3, 5, 7, 9, 11, 13]    # FFT filter radii to sweep
}

# CIFAR-10 converted to single-channel grayscale tensors.
transf = tr.Compose([tr.ToTensor(), tr.Grayscale(num_output_channels = 1) ])

# Downloads the dataset into ./data on first run (network I/O).
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transf) 
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transf)

trainloader = DataLoader(trainset, batch_size = args['BATCH_SIZE'], shuffle = True)
testloader  = DataLoader(testset, batch_size = args['BATCH_SIZE'], shuffle = False)

def cal_D(c_row, c_col, r, c) :
    """Return the Euclidean distance between the center (c_row, c_col)
    and the pixel (r, c)."""
    # math.hypot is the idiomatic, numerically stable form of
    # sqrt(dx**2 + dy**2) and avoids the manual **(1/2).
    return math.hypot(c_row - r, c_col - c)

def filter_radius(fshift, rad, low = True) :
    """Zero out frequencies of a centered FFT spectrum by radius.

    Args:
        fshift: 2-D array, output of np.fft.fftshift (DC at the center).
        rad: cutoff radius in pixels from the spectrum center.
        low: if True keep frequencies with distance <= rad (low-pass);
             otherwise keep distance >= rad (high-pass).

    Returns:
        A filtered copy of fshift (the input is not modified).
    """
    rows, cols = fshift.shape
    c_row, c_col = int(rows/2), int(cols/2)

    # Vectorized distance map replaces the original O(rows*cols)
    # Python double loop; the keep/zero boundary (strict > / <) is
    # identical, so pixels exactly at rad are kept by both filters.
    r = np.arange(rows)[:, None]
    c = np.arange(cols)[None, :]
    dist = np.sqrt((c_row - r) ** 2 + (c_col - c) ** 2)

    filter_fshift = fshift.copy()
    if low :
        filter_fshift[dist > rad] = 0
    else :
        filter_fshift[dist < rad] = 0

    return filter_fshift
    
class Encoder1(nn.Module) :
    """MLP encoder for the (low-frequency) image, plus an FFT-based
    extractor that produces the high-frequency component for Encoder2.

    Takes a flattened (batch, 32*32) grayscale image and emits
    int(latent_dim * comp_ratio) features.
    """

    def __init__(self, filtering_size: int) :
        super(Encoder1, self).__init__()

        # Stored for reference; forward() still receives filtering_size
        # explicitly, matching the original call convention.
        self.filtering_size = filtering_size

        self.encoder1 = nn.Sequential(
            nn.Linear(args['input_dim'], 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, int(args['latent_dim'] * args['comp_ratio'])),
            nn.ReLU()
        )

    def high_freq(self, x, filtering_size) :
        """Return the high-frequency component of each image in the batch
        as a (batch, 1, 32, 32) float tensor.

        NOTE: this path goes through NumPy FFTs and rebuilds a fresh
        tensor, so no gradient flows through it (same as the original).
        """
        # BUG FIX: the original used view(args['BATCH_SIZE'], 1, 32, 32),
        # which crashes on any batch whose size differs from the config
        # (e.g. a smaller final batch). Infer the batch size from x.
        imgs = x.detach().cpu().view(-1, 1, 32, 32)
        batch = imgs.shape[0]

        save_list = []
        for i in range(batch) :
            f = np.fft.fft2(imgs[i][0])
            fshift = np.fft.fftshift(f)
            # Keep only frequencies farther than filtering_size from DC.
            high_fshift = filter_radius(fshift, rad = filtering_size, low = False)
            high_ishift = np.fft.ifftshift(high_fshift)
            # Magnitude of the inverse FFT, as in the original.
            save_list.append(np.abs(np.fft.ifft2(high_ishift)))

        save_arr = np.array(save_list).reshape(batch, 1, 32, 32)
        return torch.tensor(save_arr)

    def forward(self, x, filtering_size) :
        """Return (encoded_low_features, high_frequency_images)."""
        return self.encoder1(x), self.high_freq(x, filtering_size)



class Encoder2(nn.Module) :
    """MLP encoder for the high-frequency image component.

    Flattens its input to (batch, input_dim) and emits the remaining
    int(latent_dim * (1 - comp_ratio)) latent features.
    """

    def __init__(self) :
        super(Encoder2, self).__init__()

        in_dim = args['input_dim']
        out_dim = int(args['latent_dim'] * (1 - args['comp_ratio']))

        self.encoder2 = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.ReLU()
        )

    def forward(self, x) :
        # Flatten and cast: high_freq() hands us float64 image tensors.
        flat = x.view(-1, args['input_dim']).to(torch.float32)
        return self.encoder2(flat)



class Decoder(nn.Module) :
    """MLP decoder: maps the full latent vector back to a flattened
    image in [0, 1] (Sigmoid output), shape (batch, input_dim)."""

    def __init__(self) :
        super(Decoder, self).__init__()

        # 100 -> 128 -> 256 -> 512 hidden pyramid, built programmatically.
        widths = [args['latent_dim'], 128, 256, 512]
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]) :
            layers.append(nn.Linear(w_in, w_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], args['input_dim']))
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x) :
        return self.decoder(x)



class Autoencoder(nn.Module) :
    """Split-band autoencoder with a simulated complex AWGN channel.

    Encoder1 compresses the image; Encoder2 compresses its FFT
    high-frequency component; the concatenated latent passes through
    the noisy channel and Decoder reconstructs the image.
    """

    def __init__(
        self,
        filtering_size : int,
        encoder1_class : object = Encoder1,
        encoder2_class : object = Encoder2,
        decoder_class : object = Decoder
        ) :

        super(Autoencoder, self).__init__()

        self.encoder1 = encoder1_class(filtering_size)
        self.encoder2 = encoder2_class()
        self.decoder = decoder_class()

    def addAWGN_latent(self, encoded, SNR_dB) :
        """Add complex AWGN to the latent at the requested SNR (dB).

        Consecutive value pairs of the flattened latent are treated as
        one complex symbol (re, im); each symbol gets independent
        Gaussian noise of variance N0/2 per component.

        BUG FIX: the original converted the latent to Python lists and
        rebuilt a fresh tensor, which detached the autograd graph — the
        encoders never received gradients. This version stays in torch
        ops (same noise distribution), so gradients flow end to end.
        """
        batch_size, latent_dim = encoded.shape
        flat = encoded.flatten()

        # Pair up (re, im); a trailing odd element (cannot happen with
        # latent_dim=100) is simply excluded from the power estimate.
        n_pairs = flat.numel() // 2
        re = flat[0:2 * n_pairs:2]
        im = flat[1:2 * n_pairs:2]

        # NOTE(review): following the original, "power" is the mean
        # magnitude E[|z|]; the conventional SNR definition would use
        # E[|z|^2]. Preserved as-is to keep behavior.
        signal_power = torch.sqrt(re ** 2 + im ** 2).mean()

        SNR_linear = 10 ** (SNR_dB / 10)
        N0 = signal_power / SNR_linear

        # sqrt(N0/2)*(g1 + j*g2) per symbol == std sqrt(N0/2) noise on
        # every interleaved real component.
        noisy = flat + torch.sqrt(N0 / 2) * torch.randn_like(flat)

        return noisy.view(batch_size, latent_dim).to(torch.float32)

    def forward(self, x, filtering_size, SNRdB) :
        """Encode x, transmit the latent over AWGN at SNRdB, decode."""
        after_encoder, high_freq = self.encoder1(x, filtering_size)
        high_freq_after_encoder2 = self.encoder2(high_freq)

        # Full latent: low-band features ++ high-band features.
        output = torch.cat([after_encoder, high_freq_after_encoder2], dim=1)
        output_AWGN = self.addAWGN_latent(output, SNRdB)

        return self.decoder(output_AWGN)
        

# Train one model per (filter_size, SNR) combination.
for j in range(len(args['filter_size'])):

    print("===== Filter_size = {} =====".format(args['filter_size'][j]))

    for i in range( len(args['SNR_dB']) ):

        # Fresh model/optimizer for every SNR setting.
        model = Autoencoder(filtering_size = args['filter_size'][j])
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr = args['LEARNING_RATE'])

        print("+++++ SNR = {} Training Start! +++++".format(args['SNR_dB'][i]))

        for epoch in range(args['NUM_EPOCH']) :

            running_loss = 0.0

            for data in trainloader :

                inputs = data[0]
                optimizer.zero_grad()
                # BUG FIX: the original passed the loop index `i`
                # (0, 1, 2) as the SNR in dB instead of the configured
                # value args['SNR_dB'][i], so the channel never ran at
                # the SNRs announced by the print above.
                outputs = model( inputs.view(-1, args['input_dim']), args['filter_size'][j], args['SNR_dB'][i] )
                outputs = outputs.view(-1, 1, 32, 32)
                loss = criterion(inputs, outputs)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            cost = running_loss/len(trainloader)
            print("[{} Epoch] Loss : {}".format(epoch + 1, round(cost, 6)))
        print()

 

'Wireless Comm. > Python' 카테고리의 다른 글

***Cifar10_AE(color)_20230608  (0) 2023.06.08
Traditional AutoEncoder Cifar10(gray) 2023 06 07  (0) 2023.06.07
send img with contour  (0) 2023.06.02
trainloader image 확인  (0) 2023.06.01
New : add noise to image using Norm2  (0) 2023.05.30
Comments