UOMOP

Level decision model's MSE training (x+x_f, x, x_f) 본문

DE/Code

Level decision model's MSE training (x+x_f, x, x_f)

Happy PinGu 2024. 8. 27. 23:01

Variant 1: original image + FFT features (x + x_f)

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
import pickle
from sklearn.preprocessing import StandardScaler
from torchvision import transforms
from torch.utils.data import DataLoader
from InPainting import Loader_maker_for_InPainting
from tqdm import tqdm
from utils import *

# Load the pretrained inpainting model checkpoint that matches this DIM/SNR configuration.
dim = 1024  # latent dimension of the inpainting model
snr = 40    # channel SNR (dB) the checkpoint was trained for
model_files = [f for f in os.listdir('inpaint_model') if f.startswith(f'InPaint(DIM={dim}_SNR={snr}')]
if not model_files:
    raise FileNotFoundError(f"No model found for DIM={dim} and SNR={snr}")
model_path = os.path.join('inpaint_model', model_files[-1])  # use the most recent matching checkpoint
# NOTE(review): torch.load on a full pickled module executes arbitrary code — only load trusted files.
model = torch.load(model_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Fourier-transform helper
def compute_fft(image):
    """Return the log-magnitude spectrum of *image* (2-D FFT over axes 0 and 1)."""
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=(0, 1)), axes=(0, 1))
    return np.log(np.abs(spectrum) + 1)


# Edge-amount (Sobel) helper
from scipy.ndimage import sobel


def calculate_edge_amount(image):
    """Mean Sobel gradient magnitude of *image* — a scalar edge-density score."""
    grad_x = sobel(image, axis=0)
    grad_y = sobel(image, axis=1)
    return np.hypot(grad_x, grad_y).mean()


import cv2


def calculate_canny_edge_amount(image):
    """Mean Canny edge response of *image*.

    Float inputs are assumed to lie in [0, 1] (TODO confirm against callers);
    values are clipped before the uint8 conversion so out-of-range floats no
    longer wrap around — the previous ``(image * 255).astype(np.uint8)``
    overflowed for values > 1 or < 0.
    """
    if image.dtype != np.uint8:
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    edges = cv2.Canny(image, 100, 200)  # fixed hysteresis thresholds (low=100, high=200)
    return edges.mean()


# Custom Dataset definition
class CustomDataset(Dataset):
    """Wraps prepared samples of (image, psnr_0, psnr_33, psnr_60, psnr_75).

    Each item yields the original image, its log-magnitude FFT (computed on
    the fly via ``compute_fft``), and the four target PSNR values.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, psnr_0, psnr_33, psnr_60, psnr_75 = self.data[idx][:5]

        # Frequency-domain companion of the image, as a float tensor.
        fft_features = torch.tensor(compute_fft(image.numpy()), dtype=torch.float32)

        return {
            'original_images': image,
            'x_f': fft_features,  # Fourier transformed image
            'psnr_values': torch.tensor([psnr_0, psnr_33, psnr_60, psnr_75], dtype=torch.float32),
        }


# PSNR helper
def psnr(original, recon):
    """Peak signal-to-noise ratio (dB) between two arrays scaled to [0, 1].

    Returns 100 for a perfect reconstruction (MSE == 0) to avoid log(0).
    """
    mse = np.mean((original - recon) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(1.0 / np.sqrt(mse))


# Data loading / preprocessing: build (image, PSNR-target) pairs once and cache
# them to disk so reruns skip the expensive inpainting forward passes.
train_data_file = 'prepared_traindata.pkl'
test_data_file = 'prepared_testdata.pkl'

if os.path.exists(train_data_file) and os.path.exists(test_data_file):
    with open(train_data_file, 'rb') as f:
        prepared_traindata = pickle.load(f)
    with open(test_data_file, 'rb') as f:
        prepared_testdata = pickle.load(f)
    print("Data loaded from files.")
else:
    masked_train_path = f"Masked_Train/{dim}/{snr}"
    trainset = Loader_maker_for_InPainting(root_dir=masked_train_path)
    trainloader = DataLoader(trainset, batch_size=1, shuffle=True)

    masked_test_path = f"Masked_Test/{dim}/{snr}"
    testset = Loader_maker_for_InPainting(root_dir=masked_test_path)
    testloader = DataLoader(testset, batch_size=1, shuffle=False)

    prepared_traindata = []
    prepared_testdata = []

    # Label each image with the PSNR of the inpainting model's reconstruction
    # at each masking level (0/33/60/75%). batch_size=1, so squeeze() drops
    # the batch dimension.
    with torch.no_grad():
        for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
            original_images = data['original_images'].squeeze()
            #print(original_images.shape)
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_traindata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

        # Same labeling pass for the test split.
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            original_images = data['original_images'].squeeze()
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_testdata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

    # Cache the labeled samples for subsequent runs.
    with open(train_data_file, 'wb') as f:
        pickle.dump(prepared_traindata, f)
    with open(test_data_file, 'wb') as f:
        pickle.dump(prepared_testdata, f)
    print("Data saved to files.")


# Model definition
class fc_ResBlock(nn.Module):
    """Fully-connected residual block: Linear -> ReLU -> Linear plus a skip
    connection (projected through a third Linear when Nin != Nout), followed
    by a final ReLU."""

    def __init__(self, Nin, Nout):
        super(fc_ResBlock, self).__init__()
        hidden = Nin * 2
        self.fc1 = nn.Linear(Nin, hidden)
        self.fc2 = nn.Linear(hidden, Nout)
        self.relu = nn.ReLU()
        # Project the skip path only when the dimensions differ.
        self.use_fc3 = Nin != Nout
        if self.use_fc3:
            self.fc3 = nn.Linear(Nin, Nout)

    def forward(self, x):
        residual = self.fc3(x) if self.use_fc3 else x
        out = self.fc2(self.relu(self.fc1(x)))
        return self.relu(out + residual)


class PSNRPredictionResNet(nn.Module):
    """Predicts four PSNR values (mask ratios 0/33/60/75%) from an image, its
    log-magnitude FFT, and the channel SNR scalar."""

    def __init__(self, Nc_max):
        super(PSNRPredictionResNet, self).__init__()
        # Spatial-domain feature extractor.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=4, stride=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # Frequency-domain feature extractor (same topology).
        self.conv1_f = nn.Conv2d(3, 16, kernel_size=4, stride=3, padding=1)
        self.conv2_f = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.conv3_f = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # 513 = flattened spatial features + flattened FFT features + 1 SNR
        # scalar — presumably 256 + 256 + 1 for 32x32 RGB inputs; TODO confirm.
        self.fc_resblock1 = fc_ResBlock(513, Nc_max)
        self.fc_resblock2 = fc_ResBlock(Nc_max, Nc_max)
        self.fc_out = nn.Linear(Nc_max, 4)  # psnr_0, psnr_33, psnr_60, psnr_75

        self.relu = nn.ReLU()

    def forward(self, x, x_f, snr):
        # Spatial branch.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        x = x.view(x.size(0), -1)

        # Frequency branch.
        for conv in (self.conv1_f, self.conv2_f, self.conv3_f):
            x_f = self.relu(conv(x_f))
        x_f = x_f.view(x_f.size(0), -1)

        # Broadcast the scalar SNR to one column per batch element.
        snr_tensor = torch.tensor(snr, dtype=torch.float32, device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        features = torch.cat((x, x_f, snr_tensor), dim=1)

        out = self.fc_resblock1(features)
        out = self.fc_resblock2(out)
        return self.fc_out(out)


train_loader = DataLoader(CustomDataset(prepared_traindata), batch_size=64, shuffle=True)
test_loader = DataLoader(CustomDataset(prepared_testdata), batch_size=64, shuffle=False)

# Training setup: MSE regression onto the four PSNR targets.
# NOTE(review): this rebinds `model` (previously the loaded inpainting network,
# which is only needed during data preparation above).
model = PSNRPredictionResNet(Nc_max=128).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training-loop configuration
num_epochs = 5000

# NOTE(review): the `verbose` kwarg of ReduceLROnPlateau is deprecated in recent PyTorch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True)
min_cost = 100000
previous_best_model_path = None

# Train/evaluate each epoch; keep only the best (lowest test-MSE) checkpoint on disk.
for epoch in range(num_epochs):
    trainloss = 0.0
    model.train()
    for data in train_loader:
        #print(data['original_images'].shape)


        images = data['original_images'].float().to(device)
        x_f = data['x_f'].float().to(device)
        #print(x_f.shape)
        targets = data['psnr_values'].to(device)

        #print(images.shape)
        #print(x_f.shape)

        outputs = model(images, x_f, snr)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        trainloss += loss.item()
    traincost = trainloss / len(train_loader)

    # Evaluation pass on the held-out split.
    testloss = 0.0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images = data['original_images'].float().to(device)
            x_f = data['x_f'].float().to(device)
            targets = data['psnr_values'].to(device)

            outputs = model(images, x_f, snr)
            loss = criterion(outputs, targets)

            testloss += loss.item()
        testcost = testloss / len(test_loader)

    scheduler.step(testcost)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train MSE: {traincost:.4f}    Test MSE: {testcost:.4f}")

    # New best: save this checkpoint and delete the previous best.
    if testcost < min_cost:
        save_folder = 'Lv. decision'

        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        previous_cost = min_cost
        min_cost = testcost

        if previous_best_model_path is not None:
            os.remove(previous_best_model_path)
            print(f"Performance update!! {previous_cost:.4f} to {min_cost:.4f}")

        save_path = os.path.join(save_folder, f"Lv(DIM={dim}_SNR={snr}).pt")
        torch.save(model, save_path)
        print()

        previous_best_model_path = save_path

        # NOTE(review): 'peformance' typo in the log filename — kept as-is so
        # existing logs keep appending to the same file.
        with open('Transmission_peformance.txt', 'a', encoding='utf-8') as file:
            file.write(f"\nDIM:{dim}")
            file.write(f"\nSNR({snr}dB) : {testcost:.4f}")

print("Training complete.")
Best test MSE: 0.7799

 

 

Variant 2: original image only (x)

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
import pickle
from sklearn.preprocessing import StandardScaler
from torchvision import transforms
from torch.utils.data import DataLoader
from InPainting import Loader_maker_for_InPainting
from tqdm import tqdm
from utils import *

# Load the pretrained inpainting model checkpoint that matches this DIM/SNR configuration.
dim = 1024  # latent dimension of the inpainting model
snr = 40    # channel SNR (dB) the checkpoint was trained for
model_files = [f for f in os.listdir('inpaint_model') if f.startswith(f'InPaint(DIM={dim}_SNR={snr}')]
if not model_files:
    raise FileNotFoundError(f"No model found for DIM={dim} and SNR={snr}")
model_path = os.path.join('inpaint_model', model_files[-1])  # use the most recent matching checkpoint
# NOTE(review): torch.load on a full pickled module executes arbitrary code — only load trusted files.
model = torch.load(model_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Fourier-transform helper
def compute_fft(image):
    """Return the log-magnitude spectrum of *image* (2-D FFT over axes 0 and 1)."""
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=(0, 1)), axes=(0, 1))
    return np.log(np.abs(spectrum) + 1)


# Edge-amount (Sobel) helper
from scipy.ndimage import sobel


def calculate_edge_amount(image):
    """Mean Sobel gradient magnitude of *image* — a scalar edge-density score."""
    grad_x = sobel(image, axis=0)
    grad_y = sobel(image, axis=1)
    return np.hypot(grad_x, grad_y).mean()


import cv2


def calculate_canny_edge_amount(image):
    """Mean Canny edge response of *image*.

    Float inputs are assumed to lie in [0, 1] (TODO confirm against callers);
    values are clipped before the uint8 conversion so out-of-range floats no
    longer wrap around — the previous ``(image * 255).astype(np.uint8)``
    overflowed for values > 1 or < 0.
    """
    if image.dtype != np.uint8:
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    edges = cv2.Canny(image, 100, 200)  # fixed hysteresis thresholds (low=100, high=200)
    return edges.mean()


# Custom Dataset definition
class CustomDataset(Dataset):
    """Wraps prepared samples of (image, psnr_0, psnr_33, psnr_60, psnr_75).

    Each item yields the original image, its log-magnitude FFT (computed on
    the fly via ``compute_fft``), and the four target PSNR values.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, psnr_0, psnr_33, psnr_60, psnr_75 = self.data[idx][:5]

        # Frequency-domain companion of the image, as a float tensor.
        fft_features = torch.tensor(compute_fft(image.numpy()), dtype=torch.float32)

        return {
            'original_images': image,
            'x_f': fft_features,  # Fourier transformed image
            'psnr_values': torch.tensor([psnr_0, psnr_33, psnr_60, psnr_75], dtype=torch.float32),
        }


# PSNR helper
def psnr(original, recon):
    """Peak signal-to-noise ratio (dB) between two arrays scaled to [0, 1].

    Returns 100 for a perfect reconstruction (MSE == 0) to avoid log(0).
    """
    mse = np.mean((original - recon) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(1.0 / np.sqrt(mse))


# Data loading / preprocessing: build (image, PSNR-target) pairs once and cache
# them to disk so reruns skip the expensive inpainting forward passes.
train_data_file = 'prepared_traindata.pkl'
test_data_file = 'prepared_testdata.pkl'

if os.path.exists(train_data_file) and os.path.exists(test_data_file):
    with open(train_data_file, 'rb') as f:
        prepared_traindata = pickle.load(f)
    with open(test_data_file, 'rb') as f:
        prepared_testdata = pickle.load(f)
    print("Data loaded from files.")
else:
    masked_train_path = f"Masked_Train/{dim}/{snr}"
    trainset = Loader_maker_for_InPainting(root_dir=masked_train_path)
    trainloader = DataLoader(trainset, batch_size=1, shuffle=True)

    masked_test_path = f"Masked_Test/{dim}/{snr}"
    testset = Loader_maker_for_InPainting(root_dir=masked_test_path)
    testloader = DataLoader(testset, batch_size=1, shuffle=False)

    prepared_traindata = []
    prepared_testdata = []

    # Label each image with the PSNR of the inpainting model's reconstruction
    # at each masking level (0/33/60/75%). batch_size=1, so squeeze() drops
    # the batch dimension.
    with torch.no_grad():
        for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
            original_images = data['original_images'].squeeze()
            #print(original_images.shape)
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_traindata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

        # Same labeling pass for the test split.
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            original_images = data['original_images'].squeeze()
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_testdata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

    # Cache the labeled samples for subsequent runs.
    with open(train_data_file, 'wb') as f:
        pickle.dump(prepared_traindata, f)
    with open(test_data_file, 'wb') as f:
        pickle.dump(prepared_testdata, f)
    print("Data saved to files.")


# Model definition
class fc_ResBlock(nn.Module):
    """Fully-connected residual block: Linear -> ReLU -> Linear plus a skip
    connection (projected through a third Linear when Nin != Nout), followed
    by a final ReLU."""

    def __init__(self, Nin, Nout):
        super(fc_ResBlock, self).__init__()
        hidden = Nin * 2
        self.fc1 = nn.Linear(Nin, hidden)
        self.fc2 = nn.Linear(hidden, Nout)
        self.relu = nn.ReLU()
        # Project the skip path only when the dimensions differ.
        self.use_fc3 = Nin != Nout
        if self.use_fc3:
            self.fc3 = nn.Linear(Nin, Nout)

    def forward(self, x):
        residual = self.fc3(x) if self.use_fc3 else x
        out = self.fc2(self.relu(self.fc1(x)))
        return self.relu(out + residual)


class PSNRPredictionResNet(nn.Module):
    """Ablation variant: predicts the four PSNR values (mask ratios
    0/33/60/75%) from the spatial image features and the SNR scalar only.

    The FFT input ``x_f`` is accepted for interface compatibility but unused.
    The frequency-branch conv layers are kept so the state_dict layout stays
    compatible with earlier checkpoints, but the dead forward computation —
    previously run and then discarded before the concat — is skipped.
    """

    def __init__(self, Nc_max):
        super(PSNRPredictionResNet, self).__init__()
        # Spatial-domain feature extractor (the only branch that feeds the head).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # Unused in forward (image-only ablation); retained for checkpoint
        # compatibility with the other variants.
        self.conv1_f = nn.Conv2d(3, 32, kernel_size=4, stride=3, padding=1)
        self.conv2_f = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # 1601 = flattened conv output + 1 SNR scalar — presumably 1600 + 1
        # for 32x32 RGB inputs; TODO confirm input resolution.
        self.fc_resblock1 = fc_ResBlock(1601, Nc_max)
        self.fc_resblock2 = fc_ResBlock(Nc_max, Nc_max)
        self.fc_out = nn.Linear(Nc_max, 4)  # psnr_0, psnr_33, psnr_60, psnr_75

        self.relu = nn.ReLU()

    def forward(self, x, x_f, snr):
        # Spatial branch.
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten per batch element

        # The FFT branch was computed here before but never concatenated —
        # wasted FLOPs with no gradient flow, so it is skipped entirely.

        # Broadcast the scalar SNR to one column per batch element.
        snr_tensor = torch.tensor(snr, dtype=torch.float32, device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        combined_features = torch.cat((x, snr_tensor), dim=1)

        # Fully connected residual head.
        x = self.fc_resblock1(combined_features)
        x = self.fc_resblock2(x)
        x = self.fc_out(x)

        return x


train_loader = DataLoader(CustomDataset(prepared_traindata), batch_size=64, shuffle=True)
test_loader = DataLoader(CustomDataset(prepared_testdata), batch_size=64, shuffle=False)

# Training setup: MSE regression onto the four PSNR targets.
# NOTE(review): this rebinds `model` (previously the loaded inpainting network,
# which is only needed during data preparation above).
model = PSNRPredictionResNet(Nc_max=512).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training-loop configuration
num_epochs = 5000

# NOTE(review): the `verbose` kwarg of ReduceLROnPlateau is deprecated in recent PyTorch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True)
min_cost = 100000
previous_best_model_path = None

# Train/evaluate each epoch; keep only the best (lowest test-MSE) checkpoint on disk.
for epoch in range(num_epochs):
    trainloss = 0.0
    model.train()
    for data in train_loader:
        #print(data['original_images'].shape)


        images = data['original_images'].float().to(device)
        x_f = data['x_f'].float().to(device)
        #print(x_f.shape)
        targets = data['psnr_values'].to(device)

        #print(images.shape)
        #print(x_f.shape)

        outputs = model(images, x_f, snr)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        trainloss += loss.item()
    traincost = trainloss / len(train_loader)

    # Evaluation pass on the held-out split.
    testloss = 0.0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images = data['original_images'].float().to(device)
            x_f = data['x_f'].float().to(device)
            targets = data['psnr_values'].to(device)

            outputs = model(images, x_f, snr)
            loss = criterion(outputs, targets)

            testloss += loss.item()
        testcost = testloss / len(test_loader)

    scheduler.step(testcost)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train MSE: {traincost:.4f}    Test MSE: {testcost:.4f}")

    # New best: save this checkpoint and delete the previous best.
    if testcost < min_cost:
        save_folder = 'Lv. decision'

        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        previous_cost = min_cost
        min_cost = testcost

        if previous_best_model_path is not None:
            os.remove(previous_best_model_path)
            print(f"Performance update!! {previous_cost:.4f} to {min_cost:.4f}")

        save_path = os.path.join(save_folder, f"Lv(DIM={dim}_SNR={snr}).pt")
        torch.save(model, save_path)
        print()

        previous_best_model_path = save_path

        # NOTE(review): 'peformance' typo in the log filename — kept as-is so
        # existing logs keep appending to the same file.
        with open('Transmission_peformance.txt', 'a', encoding='utf-8') as file:
            file.write(f"\nDIM:{dim}")
            file.write(f"\nSNR({snr}dB) : {testcost:.4f}")

print("Training complete.")
Best test MSE: 0.9191

 

 

 

 

Variant 3: FFT features only (x_f)

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
import pickle
from sklearn.preprocessing import StandardScaler
from torchvision import transforms
from torch.utils.data import DataLoader
from InPainting import Loader_maker_for_InPainting
from tqdm import tqdm
from utils import *

# Load the pretrained inpainting model checkpoint that matches this DIM/SNR configuration.
dim = 1024  # latent dimension of the inpainting model
snr = 40    # channel SNR (dB) the checkpoint was trained for
model_files = [f for f in os.listdir('inpaint_model') if f.startswith(f'InPaint(DIM={dim}_SNR={snr}')]
if not model_files:
    raise FileNotFoundError(f"No model found for DIM={dim} and SNR={snr}")
model_path = os.path.join('inpaint_model', model_files[-1])  # use the most recent matching checkpoint
# NOTE(review): torch.load on a full pickled module executes arbitrary code — only load trusted files.
model = torch.load(model_path)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Fourier-transform helper
def compute_fft(image):
    """Return the log-magnitude spectrum of *image* (2-D FFT over axes 0 and 1)."""
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=(0, 1)), axes=(0, 1))
    return np.log(np.abs(spectrum) + 1)


# Edge-amount (Sobel) helper
from scipy.ndimage import sobel


def calculate_edge_amount(image):
    """Mean Sobel gradient magnitude of *image* — a scalar edge-density score."""
    grad_x = sobel(image, axis=0)
    grad_y = sobel(image, axis=1)
    return np.hypot(grad_x, grad_y).mean()


import cv2


def calculate_canny_edge_amount(image):
    """Mean Canny edge response of *image*.

    Float inputs are assumed to lie in [0, 1] (TODO confirm against callers);
    values are clipped before the uint8 conversion so out-of-range floats no
    longer wrap around — the previous ``(image * 255).astype(np.uint8)``
    overflowed for values > 1 or < 0.
    """
    if image.dtype != np.uint8:
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    edges = cv2.Canny(image, 100, 200)  # fixed hysteresis thresholds (low=100, high=200)
    return edges.mean()


# Custom Dataset definition
class CustomDataset(Dataset):
    """Wraps prepared samples of (image, psnr_0, psnr_33, psnr_60, psnr_75).

    Each item yields the original image, its log-magnitude FFT (computed on
    the fly via ``compute_fft``), and the four target PSNR values.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, psnr_0, psnr_33, psnr_60, psnr_75 = self.data[idx][:5]

        # Frequency-domain companion of the image, as a float tensor.
        fft_features = torch.tensor(compute_fft(image.numpy()), dtype=torch.float32)

        return {
            'original_images': image,
            'x_f': fft_features,  # Fourier transformed image
            'psnr_values': torch.tensor([psnr_0, psnr_33, psnr_60, psnr_75], dtype=torch.float32),
        }


# PSNR helper
def psnr(original, recon):
    """Peak signal-to-noise ratio (dB) between two arrays scaled to [0, 1].

    Returns 100 for a perfect reconstruction (MSE == 0) to avoid log(0).
    """
    mse = np.mean((original - recon) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(1.0 / np.sqrt(mse))


# Data loading / preprocessing: build (image, PSNR-target) pairs once and cache
# them to disk so reruns skip the expensive inpainting forward passes.
train_data_file = 'prepared_traindata.pkl'
test_data_file = 'prepared_testdata.pkl'

if os.path.exists(train_data_file) and os.path.exists(test_data_file):
    with open(train_data_file, 'rb') as f:
        prepared_traindata = pickle.load(f)
    with open(test_data_file, 'rb') as f:
        prepared_testdata = pickle.load(f)
    print("Data loaded from files.")
else:
    masked_train_path = f"Masked_Train/{dim}/{snr}"
    trainset = Loader_maker_for_InPainting(root_dir=masked_train_path)
    trainloader = DataLoader(trainset, batch_size=1, shuffle=True)

    masked_test_path = f"Masked_Test/{dim}/{snr}"
    testset = Loader_maker_for_InPainting(root_dir=masked_test_path)
    testloader = DataLoader(testset, batch_size=1, shuffle=False)

    prepared_traindata = []
    prepared_testdata = []

    # Label each image with the PSNR of the inpainting model's reconstruction
    # at each masking level (0/33/60/75%). batch_size=1, so squeeze() drops
    # the batch dimension.
    with torch.no_grad():
        for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
            original_images = data['original_images'].squeeze()
            #print(original_images.shape)
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_traindata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

        # Same labeling pass for the test split.
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            original_images = data['original_images'].squeeze()
            recon_0 = data['recon_0'].squeeze(1).to(device)
            recon_masked_33 = data['recon_masked_33'].squeeze(1).to(device)
            recon_masked_60 = data['recon_masked_60'].squeeze(1).to(device)
            recon_masked_75 = data['recon_masked_75'].squeeze(1).to(device)

            outputs_0 = model(recon_0)
            outputs_33 = model(recon_masked_33)
            outputs_60 = model(recon_masked_60)
            outputs_75 = model(recon_masked_75)

            psnr_0 = psnr(original_images.squeeze().cpu().numpy(), outputs_0.squeeze().cpu().numpy())
            psnr_33 = psnr(original_images.squeeze().cpu().numpy(), outputs_33.squeeze().cpu().numpy())
            psnr_60 = psnr(original_images.squeeze().cpu().numpy(), outputs_60.squeeze().cpu().numpy())
            psnr_75 = psnr(original_images.squeeze().cpu().numpy(), outputs_75.squeeze().cpu().numpy())

            prepared_testdata.append(
                [original_images.cpu(), psnr_0, psnr_33, psnr_60, psnr_75])

    # Cache the labeled samples for subsequent runs.
    with open(train_data_file, 'wb') as f:
        pickle.dump(prepared_traindata, f)
    with open(test_data_file, 'wb') as f:
        pickle.dump(prepared_testdata, f)
    print("Data saved to files.")


# Model definition
class fc_ResBlock(nn.Module):
    """Fully-connected residual block: Linear -> ReLU -> Linear plus a skip
    connection (projected through a third Linear when Nin != Nout), followed
    by a final ReLU."""

    def __init__(self, Nin, Nout):
        super(fc_ResBlock, self).__init__()
        hidden = Nin * 2
        self.fc1 = nn.Linear(Nin, hidden)
        self.fc2 = nn.Linear(hidden, Nout)
        self.relu = nn.ReLU()
        # Project the skip path only when the dimensions differ.
        self.use_fc3 = Nin != Nout
        if self.use_fc3:
            self.fc3 = nn.Linear(Nin, Nout)

    def forward(self, x):
        residual = self.fc3(x) if self.use_fc3 else x
        out = self.fc2(self.relu(self.fc1(x)))
        return self.relu(out + residual)


class PSNRPredictionResNet(nn.Module):
    """Ablation variant: predicts the four PSNR values (mask ratios
    0/33/60/75%) from the FFT features and the SNR scalar only.

    The spatial image input ``x`` is accepted for interface compatibility but
    unused. The spatial-branch conv layers are kept so the state_dict layout
    stays compatible with earlier checkpoints, but the dead forward
    computation — previously run and then discarded before the concat — is
    skipped.
    """

    def __init__(self, Nc_max):
        super(PSNRPredictionResNet, self).__init__()
        # Unused in forward (FFT-only ablation); retained for checkpoint
        # compatibility with the other variants.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # Frequency-domain feature extractor (the only branch that feeds the head).
        self.conv1_f = nn.Conv2d(3, 32, kernel_size=4, stride=3, padding=1)
        self.conv2_f = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)

        # 1601 = flattened conv output + 1 SNR scalar — presumably 1600 + 1
        # for 32x32 RGB inputs; TODO confirm input resolution.
        self.fc_resblock1 = fc_ResBlock(1601, Nc_max)
        self.fc_resblock2 = fc_ResBlock(Nc_max, Nc_max)
        self.fc_out = nn.Linear(Nc_max, 4)  # psnr_0, psnr_33, psnr_60, psnr_75

        self.relu = nn.ReLU()

    def forward(self, x, x_f, snr):
        # The spatial branch was computed here before but never concatenated —
        # wasted FLOPs with no gradient flow, so it is skipped entirely.

        # Frequency branch.
        x_f = self.relu(self.conv1_f(x_f))
        x_f = self.relu(self.conv2_f(x_f))
        x_f = x_f.view(x_f.size(0), -1)  # flatten per batch element

        # Broadcast the scalar SNR to one column per batch element.
        snr_tensor = torch.tensor(snr, dtype=torch.float32, device=x_f.device).unsqueeze(0).repeat(x_f.size(0), 1)
        combined_features = torch.cat((x_f, snr_tensor), dim=1)

        # Fully connected residual head.
        out = self.fc_resblock1(combined_features)
        out = self.fc_resblock2(out)
        out = self.fc_out(out)

        return out


train_loader = DataLoader(CustomDataset(prepared_traindata), batch_size=64, shuffle=True)
test_loader = DataLoader(CustomDataset(prepared_testdata), batch_size=64, shuffle=False)

# Training setup: MSE regression onto the four PSNR targets.
# NOTE(review): this rebinds `model` (previously the loaded inpainting network,
# which is only needed during data preparation above).
model = PSNRPredictionResNet(Nc_max=512).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training-loop configuration
num_epochs = 5000

# NOTE(review): the `verbose` kwarg of ReduceLROnPlateau is deprecated in recent PyTorch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True)
min_cost = 100000
previous_best_model_path = None

# Train/evaluate each epoch; keep only the best (lowest test-MSE) checkpoint on disk.
for epoch in range(num_epochs):
    trainloss = 0.0
    model.train()
    for data in train_loader:
        #print(data['original_images'].shape)


        images = data['original_images'].float().to(device)
        x_f = data['x_f'].float().to(device)
        #print(x_f.shape)
        targets = data['psnr_values'].to(device)

        #print(images.shape)
        #print(x_f.shape)

        outputs = model(images, x_f, snr)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        trainloss += loss.item()
    traincost = trainloss / len(train_loader)

    # Evaluation pass on the held-out split.
    testloss = 0.0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images = data['original_images'].float().to(device)
            x_f = data['x_f'].float().to(device)
            targets = data['psnr_values'].to(device)

            outputs = model(images, x_f, snr)
            loss = criterion(outputs, targets)

            testloss += loss.item()
        testcost = testloss / len(test_loader)

    scheduler.step(testcost)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train MSE: {traincost:.4f}    Test MSE: {testcost:.4f}")

    # New best: save this checkpoint and delete the previous best.
    if testcost < min_cost:
        save_folder = 'Lv. decision'

        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        previous_cost = min_cost
        min_cost = testcost

        if previous_best_model_path is not None:
            os.remove(previous_best_model_path)
            print(f"Performance update!! {previous_cost:.4f} to {min_cost:.4f}")

        save_path = os.path.join(save_folder, f"Lv(DIM={dim}_SNR={snr}).pt")
        torch.save(model, save_path)
        print()

        previous_best_model_path = save_path

        # NOTE(review): 'peformance' typo in the log filename — kept as-is so
        # existing logs keep appending to the same file.
        with open('Transmission_peformance.txt', 'a', encoding='utf-8') as file:
            file.write(f"\nDIM:{dim}")
            file.write(f"\nSNR({snr}dB) : {testcost:.4f}")

print("Training complete.")
Best test MSE: 1.0928

'DE > Code' 카테고리의 다른 글

Transmission code 20240911  (0) 2024.09.11
FLOPs (Proposed encoder)  (0) 2024.08.28
Cifar10 Fourier  (0) 2024.08.27
DE : Selection (33%, 60%, 75%)  (0) 2024.08.15
Masked, Reshaped, index Gen  (0) 2024.08.10
Comments