UOMOP

send img with contour 본문

Wireless Comm./Python

send img with contour

Happy PinGu 2023. 6. 2. 02:32
import cv2
import math
import time
import torch
import random

import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from numpy import sqrt
from tqdm import trange
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

# Preprocessing pipeline applied to every sample: convert to a tensor, then
# reduce to a single grayscale channel.
transf = tr.Compose([
    tr.ToTensor(),
    tr.Grayscale(num_output_channels=1),
])

# CIFAR-10 train / test splits stored under ./data (download=True); both use
# the preprocessing pipeline defined above.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transf,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transf,
)

# Mini-batch iterators of 50 samples; only the training split is shuffled.
trainloader = DataLoader(trainset, batch_size=50, shuffle=True)
testloader = DataLoader(testset, batch_size=50, shuffle=False)

def cal_D(c_row, c_col, r, c):
    """Return the Euclidean distance between (c_row, c_col) and (r, c)."""
    d_row = c_row - r
    d_col = c_col - c
    squared = d_row ** 2 + d_col ** 2
    return squared ** (1 / 2)

def filter_radius(fshift, rad, low = True):
    """Zero out bins of a centred 2-D spectrum based on distance from centre.

    Distances are Euclidean, measured in bins from the centre index
    (rows // 2, cols // 2). The whole distance map is built with NumPy
    broadcasting instead of a Python loop over every bin.

    Args:
        fshift: 2-D NumPy array (e.g. the output of ``np.fft.fftshift``).
            It is not modified; a filtered copy is returned.
        rad: Cut-off radius in bins.
        low: If truthy, act as a low-pass filter (bins farther than ``rad``
            are zeroed). Otherwise act as a high-pass filter (bins closer
            than ``rad`` are zeroed). Bins exactly at ``rad`` are kept in
            both modes.

    Returns:
        A copy of ``fshift`` with the rejected bins set to 0.
    """
    rows, cols = fshift.shape
    c_row, c_col = rows // 2, cols // 2

    # Open grids broadcast to a (rows, cols) map of distances to the centre.
    row_idx, col_idx = np.ogrid[:rows, :cols]
    dist = np.sqrt((row_idx - c_row) ** 2 + (col_idx - c_col) ** 2)

    # Strict comparisons so that bins lying exactly on the radius survive.
    reject = dist > rad if low else dist < rad

    filter_fshift = fshift.copy()
    filter_fshift[reject] = 0
    return filter_fshift
    
    
# Experiment hyper-parameters read by the model classes and the training loop.
args = dict(
    BATCH_SIZE=50,          # samples per mini-batch
    LEARNING_RATE=0.001,    # learning rate handed to the optimizer
    NUM_EPOCH=50,           # passes over the training set per run
    SNR_dB=[1, 10, 20],     # one training run is started per entry
    latent_dim=50,          # output width of each encoder branch
    input_dim=32 * 32,      # length of a flattened 32x32 image
    testdata_num=10,        # not referenced by the code visible in this file
)



        
        
class Encoder(nn.Module):
    """First encoder branch: dense latent code plus a high-pass filtered image.

    ``forward`` returns two things: the output of a three-layer fully
    connected encoder, and a high-frequency version of the input images
    computed with NumPy FFTs. The FFT path runs outside autograd, so no
    gradient flows through the filtered images.

    Args:
        filtering_size: Default high-pass cut-off radius (in frequency bins),
            used whenever ``forward`` / ``high_freq`` is called without an
            explicit radius.
    """

    def __init__(self, filtering_size: int):
        super().__init__()

        # Keep the constructor argument so callers may omit it later on.
        self.filtering_size = filtering_size

        self.encoder = nn.Sequential(
            nn.Linear(args['input_dim'], 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, args['latent_dim']),
            nn.ReLU()
        )

    def high_freq(self, x, filtering_size=None):
        """Return the high-pass filtered version of every image in ``x``.

        Args:
            x: Tensor whose elements can be reshaped to (-1, 32, 32); the
                batch size is inferred rather than taken from
                ``args['BATCH_SIZE']``, so partial batches work too.
            filtering_size: Cut-off radius passed to ``filter_radius``;
                falls back to the value given to the constructor when None.

        Returns:
            Tensor of shape (batch, 1, 32, 32) holding the magnitude of the
            inverse FFT of each filtered spectrum, placed on ``x``'s device.
        """
        if filtering_size is None:
            filtering_size = self.filtering_size

        # NumPy needs a detached CPU array; this also makes the method safe
        # for inputs that require grad or live on an accelerator.
        images = x.detach().cpu().reshape(-1, 32, 32).numpy()

        # fft2 works on the last two axes, so the whole batch is transformed
        # at once; the shifts must be restricted to those same axes.
        spectrum = np.fft.fftshift(np.fft.fft2(images), axes=(-2, -1))

        # filter_radius handles one 2-D spectrum at a time.
        filtered = np.stack([
            filter_radius(single, rad=filtering_size, low=False)
            for single in spectrum
        ])

        restored = np.fft.ifft2(np.fft.ifftshift(filtered, axes=(-2, -1)))
        high_img = np.abs(restored)

        return torch.tensor(high_img).unsqueeze(1).to(x.device)

    def forward(self, x, filtering_size=None):
        """Encode ``x`` and compute its high-frequency images.

        Args:
            x: Batch of images, flattened or not; it is reshaped to
                (-1, args['input_dim']) before entering the dense layers.
            filtering_size: Optional high-pass cut-off radius; defaults to
                the constructor value.

        Returns:
            Tuple ``(latent, high_freq_images)``.
        """
        flat = x.reshape(-1, args['input_dim'])
        return self.encoder(flat), self.high_freq(x, filtering_size)
        
        
class Encoder2(nn.Module):
    """Fully connected encoder used for the second (high-frequency) branch.

    Maps flattened inputs of length ``args['input_dim']`` to a code of length
    ``args['latent_dim']`` through 512- and 256-unit hidden layers, with a
    ReLU after every linear layer.
    """

    def __init__(self):
        super().__init__()

        widths = [args['input_dim'], 512, 256, args['latent_dim']]
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.ReLU())

        self.encoder2 = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x``, cast it to float32 and return its latent code."""
        flat = x.view(-1, args['input_dim'])
        flat = flat.to(torch.float32)
        return self.encoder2(flat)
        
        
        
class Decoder(nn.Module):
    """Fully connected decoder that rebuilds a flattened image.

    Takes a vector of length ``2 * args['latent_dim']`` and maps it through
    256- and 512-unit ReLU layers to ``args['input_dim']`` values squashed
    by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Linear(args['latent_dim'] * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, args['input_dim']),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the decoder stack and return the result."""
        reconstruction = self.decoder(x)
        return reconstruction
        
class Autoencoder(nn.Module):
    """Two-branch autoencoder.

    The first encoder yields a latent code and a high-frequency image; the
    second encoder turns that image into another latent code; the decoder
    consumes the two codes concatenated along dimension 1.

    Args:
        filtering_size: Forwarded to the first encoder's constructor.
        encoder_class: Class instantiated as the first encoder.
        encoder2_class: Class instantiated (without arguments) as the
            second encoder.
        decoder_class: Class instantiated (without arguments) as the decoder.
    """

    def __init__(
        self,
        filtering_size : int,
        encoder_class : object = Encoder,
        encoder2_class : object = Encoder2,
        decoder_class : object = Decoder
        ):
        super().__init__()

        self.encoder = encoder_class(filtering_size)
        self.encoder2 = encoder2_class()
        self.decoder = decoder_class()

    def forward(self, x, filtering_size):
        """Return the decoder output for input batch ``x``.

        ``filtering_size`` is handed on to the first encoder together
        with ``x``.
        """
        latent, filtered_images = self.encoder(x, filtering_size)
        latent_hf = self.encoder2(filtered_images)

        joint = torch.cat((latent, latent_hf), dim=1)
        return self.decoder(joint)

# High-pass cut-off radius (in frequency bins) used for both building the
# model and calling it; kept in one place instead of two literal 5s.
FILTERING_SIZE = 5

# One independent training run per configured SNR value.
# NOTE(review): the SNR value is only used in the log message below; no
# channel noise is applied anywhere in the visible code, so every run trains
# the same setup from a fresh initialisation. Confirm whether a noisy channel
# was meant to sit between the encoders and the decoder.
for snr_db in args['SNR_dB']:

    model = Autoencoder(filtering_size=FILTERING_SIZE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args['LEARNING_RATE'])

    print("+++++ SNR = {} Training Start! +++++".format(snr_db))

    for epoch in range(args['NUM_EPOCH']):

        running_loss = 0.0

        # Labels are ignored: the images themselves are the targets.
        for inputs, _ in trainloader:

            optimizer.zero_grad()

            # The model works on flattened images; restore the image shape
            # afterwards so the output lines up with the input batch.
            outputs = model(inputs.view(-1, args['input_dim']), FILTERING_SIZE)
            outputs = outputs.view(-1, 1, 32, 32)

            # (prediction, target) order; MSE is symmetric, so the loss value
            # equals the one obtained with the arguments swapped.
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean per-batch loss over the epoch.
        cost = running_loss / len(trainloader)
        print("[{}] Loss : {}".format(epoch + 1, round(cost, 6)))

Comments