UOMOP

High Attention Selection 본문

Main

High Attention Selection

Happy PinGu 2024. 3. 15. 18:29
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
from torch.utils.data import DataLoader, Dataset
import time
from params import *
import os
from tqdm import tqdm


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


def count_parameters(model):
    """Return the total number of elements in the trainable parameters of *model*.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


def attention_based_selection(images, patch_size, mask_ratio):
    """Keep the most mutually correlated patches of each image and zero the rest.

    Args:
        images: batch of images shaped (B, C, H, W).
        patch_size: side length of the square, non-overlapping patches.
        mask_ratio: fraction of patches to zero out; ``int(N * (1 - mask_ratio))``
            patches are kept, where N is the number of patches per image.

    Returns:
        A copy of ``images`` in which every patch that was not selected is zero.
    """
    batch_size, channels = images.size(0), images.size(1)

    # F.unfold returns (B, C * k * k, L): the channel/kernel axes are flattened
    # together and the patch index L comes LAST. Split the middle axis back
    # into (C, k, k) and move L forward so that each entry along dim 1 holds
    # the pixels of one real image patch. (Viewing the unfold output directly
    # as (B, C, L, k, k) would mix pixels from different patches.)
    unfolded = F.unfold(images, kernel_size=patch_size, stride=patch_size)
    batch_patches = unfolded.view(batch_size, channels, patch_size, patch_size, -1)
    batch_patches = batch_patches.permute(0, 4, 1, 2, 3).contiguous()  # (B, L, C, k, k)

    correlation_map = calculate_correlation_batch(batch_patches)
    total_patches = batch_patches.size(1)

    # Number of patches that survive masking.
    patch_count = int(total_patches * (1 - mask_ratio))
    selected_patches_indexes = select_patches_batch_high_correlation(correlation_map, patch_count)

    return mask_unselected_patches_batch(images, selected_patches_indexes, patch_size)



def calculate_correlation_batch(patches):
    """Compute the patch-to-patch correlation matrix for every image in a batch.

    Each patch is flattened, mean-centred and compared with every other patch
    of the same image via normalised dot products.

    Args:
        patches: tensor shaped (B, N, C, H, W) holding N patches per image.

    Returns:
        Tensor shaped (B, N, N). Patches with zero norm after centring
        (constant patches) yield zero correlation with everything.
    """
    batch, count, _channels, _height, _width = patches.shape
    flat = patches.view(batch, count, -1)

    centered = flat - flat.mean(dim=2, keepdim=True)
    magnitude = centered.norm(dim=2, keepdim=True)
    # Avoid division by zero for constant patches: their numerator is already 0.
    magnitude = magnitude.masked_fill(magnitude == 0, 1)

    numerator = torch.matmul(centered, centered.transpose(1, 2))
    denominator = torch.matmul(magnitude, magnitude.transpose(1, 2))

    return numerator / denominator

def select_patches_batch_high_correlation(correlation_map, patch_count):
    """Greedily pick a chain of highly correlated patches for every image.

    Selection starts at the middle patch (index ``N // 2``). Each following
    pick is the not-yet-selected patch with the highest correlation to the
    most recently selected one.

    Args:
        correlation_map: float tensor shaped (B, N, N).
        patch_count: number of patches to select per image.

    Returns:
        Long tensor shaped (B, patch_count), created on the default device,
        holding the selected patch indices in selection order. When
        ``patch_count <= 0`` the result has shape (B, 0).
    """
    B, N, _ = correlation_map.shape
    if patch_count <= 0:
        return torch.zeros((B, 0), dtype=torch.long)

    selected_indexes = torch.zeros((B, patch_count), dtype=torch.long)

    for b in range(B):
        selected = [N // 2]  # every image starts from the middle patch
        for _ in range(1, patch_count):
            scores = correlation_map[b, selected[-1]].clone()
            # Exclude already chosen patches. -inf is required here: -1 is a
            # legitimate correlation value, so masking with -1 lets argmax
            # return an already selected index when the remaining candidates
            # are all perfectly anti-correlated.
            scores[selected] = float('-inf')
            selected.append(scores.argmax().item())
        selected_indexes[b] = torch.tensor(selected)

    return selected_indexes

def mask_unselected_patches_batch(images, selected_indexes, patch_size):
    """Zero out every patch of each image that is not listed in ``selected_indexes``.

    Patch indices are interpreted in row-major order over the grid of
    non-overlapping ``patch_size`` x ``patch_size`` patches.

    Args:
        images: tensor shaped (B, C, H, W).
        selected_indexes: per-image patch indices to keep, shaped (B, K).
        patch_size: side length of a patch in pixels.

    Returns:
        A new tensor shaped like ``images``; the input is not modified.
    """
    B, C, H, W = images.shape
    # Row-major patch numbering wraps after one full row of patches, so the
    # divisor must be the number of patches along the WIDTH (not the height).
    patches_per_row = W // patch_size

    masked_images = images.clone()
    for b in range(B):
        # True marks pixels to be zeroed; selected patches are cleared below.
        mask = torch.ones((H, W), dtype=torch.bool, device=images.device)
        for index in torch.as_tensor(selected_indexes[b]).tolist():
            row = (index // patches_per_row) * patch_size
            col = (index % patches_per_row) * patch_size
            mask[row:row + patch_size, col:col + patch_size] = False
        masked_images[b, :, mask] = 0

    return masked_images


class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to ``latent_dim`` features.

    Three stride-2 convolutions halve the spatial size each time; the final
    linear layer expects a 128 x 4 x 4 feature map, i.e. 3 x 32 x 32 inputs.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Channel progression 3 -> 32 -> 64 -> 128, each conv followed by ReLU.
        # For 32x32 inputs the spatial size goes 32 -> 16 -> 8 -> 4.
        widths = (3, 32, 64, 128)
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
        stages.append(nn.Flatten())
        stages.append(nn.Linear(4 * 4 * 128, self.latent_dim))

        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` and return a (batch, latent_dim) tensor."""
        latent = self.encoder(x)
        return latent

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    A linear layer expands the latent vector to a 128 x 4 x 4 feature map,
    which is upsampled three times by a factor of two (4 -> 8 -> 16 -> 32)
    and squashed to [0, 1] by a final sigmoid.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        stages = [
            nn.Linear(self.latent_dim, 4 * 4 * 128),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
        ]
        # Two "upsample + refine" stages: 128 -> 64 (8x8) and 64 -> 32 (16x16).
        for in_ch, out_ch in ((128, 64), (64, 32)):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
            stages.append(nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
        # Final upsample to 3 channels at 32x32, values limited to [0, 1].
        stages.append(nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1))
        stages.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a (batch, latent_dim) tensor into a (batch, 3, 32, 32) tensor."""
        reconstruction = self.decoder(x)
        return reconstruction


class Autoencoder(nn.Module):
    """Encoder -> AWGN channel -> decoder pipeline.

    The latent vector produced by the encoder is normalised, corrupted with
    additive white Gaussian noise at a requested SNR, and then decoded.
    Noise is injected on every forward pass, in both train and eval mode.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def AWGN(self, input, SNRdB):
        """Normalise ``input`` along dim 1 and add Gaussian noise at ``SNRdB``.

        Args:
            input: tensor whose dim 1 is the latent axis
                (assumed (batch, latent_dim), as produced by the encoder).
            SNRdB: signal-to-noise ratio in decibels.

        Returns:
            Tensor with the same shape, dtype and device as ``input``.
        """
        # Unit L2 norm along dim 1.
        normalized_tensor = f.normalize(input, dim=1)

        # Convert dB to a linear power ratio.
        SNR = 10.0 ** (SNRdB / 10.0)

        # Noise standard deviation: variance = 1 / (latent_dim * SNR).
        std = 1 / math.sqrt(self.latent_dim * SNR)

        # Sample the noise directly on the input's device/dtype instead of
        # relying on a module-level global device, so the layer keeps working
        # wherever the model and its inputs live.
        noise = torch.randn_like(normalized_tensor) * std

        return normalized_tensor + noise

    def forward(self, x, SNRdB):
        """Encode ``x``, pass it through the noisy channel and decode it."""
        encoded = self.encoder(x)
        channel_output = self.AWGN(encoded, SNRdB)
        decoded = self.decoder(channel_output)

        return decoded

def preprocess_and_save_dataset(dataset, root_dir, patch_size, mask_ratio):
    """Mask every sample of ``dataset`` and store the results under ``root_dir``.

    For sample ``i`` a file ``data_{i}.pt`` is written containing a dict with
    the masked image (``'masked_images'``) and the untouched image
    (``'original_images'``). ``root_dir`` is created if it does not exist.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs; the
            label is ignored.
        root_dir: output directory.
        patch_size: patch side length forwarded to the patch selection.
        mask_ratio: fraction of patches to mask, forwarded likewise.
    """
    os.makedirs(root_dir, exist_ok=True)

    progress = tqdm(enumerate(dataset), total=len(dataset))
    for sample_idx, (image, _label) in progress:
        # The selection routine works on batches, so add and later drop a
        # leading batch dimension of size one.
        single_batch = image.unsqueeze(0)
        masked = attention_based_selection(single_batch, patch_size, mask_ratio)

        record = {
            'masked_images': masked.squeeze(0),
            'original_images': image,
        }
        target = os.path.join(root_dir, f'data_{sample_idx}.pt')
        torch.save(record, target)

class PreprocessedCIFAR10Dataset(Dataset):
    """Dataset over the ``.pt`` files found in a preprocessed directory.

    Every ``.pt`` file in ``root_dir`` is expected to hold a dict with the
    keys ``'masked_images'`` and ``'original_images'``.

    Args:
        root_dir: directory containing the ``.pt`` files.
        transform: optional callable applied to both the masked and the
            original image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that a given index always maps to the same file.
        self.file_paths = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.endswith('.pt')
        )

    def __len__(self):
        """Return the number of ``.pt`` files discovered in ``root_dir``."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(masked_image, original_image)``."""
        file_path = self.file_paths[idx]
        # NOTE: torch.load unpickles the file; only point this dataset at
        # directories produced by trusted code.
        data = torch.load(file_path)
        masked_images = data['masked_images']
        original_images = data['original_images']

        if self.transform is not None:
            masked_images = self.transform(masked_images)
            original_images = self.transform(original_images)

        return masked_images, original_images


def _psnr_from_mse(mse):
    """Return the PSNR in dB (peak value 1.0) for ``mse``, rounded to 3 decimals."""
    if mse <= 0:
        # A perfect reconstruction has unbounded PSNR; avoid dividing by zero.
        return float('inf')
    return round(10 * math.log10(1.0 / mse), 3)


def train(latent_dim, patch_size, mask_ratio, trainloader, testloader):
    """Train one autoencoder per SNR listed in ``params['SNR']``.

    For every SNR a fresh model is trained for at most ``params['EP']`` epochs
    with Adam (learning rate ``params['LR']``) and an MSE loss between the
    original images and the reconstruction of the masked images. Training for
    an SNR stops early after 20 consecutive epochs without a new minimum
    validation loss. Whenever the validation PSNR reaches a new maximum the
    whole model is saved to the ``trained_model`` folder.

    Args:
        latent_dim: size of the autoencoder bottleneck.
        patch_size: patch size used for preprocessing (only used in the
            checkpoint file name).
        mask_ratio: mask ratio used for preprocessing (only used in the
            checkpoint file name).
        trainloader: iterable of ``(masked_images, original_images)`` batches
            used for optimisation.
        testloader: iterable of ``(masked_images, original_images)`` batches
            used for validation.
    """
    save_folder = 'trained_model'
    n_epochs_stop = 20  # early-stopping patience, in epochs

    for snr_db in params['SNR']:

        model = Autoencoder(latent_dim=latent_dim).to(device)
        print("Model size : {}".format(count_parameters(model)))
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=params['LR'])

        min_test_cost = float('inf')
        epochs_no_improve = 0  # consecutive epochs without a new best validation loss
        max_psnr = 0

        print("+++++ SNR = {} Training Start! +++++\t".format(snr_db))

        for epoch in range(params['EP']):
            # ============================== Train ==============================
            model.train()
            epoch_start = time.time()
            train_loss = 0.0

            for masked_images, original_images in trainloader:
                original_images = original_images.to(device)
                masked_images = masked_images.to(device)

                optimizer.zero_grad()
                outputs = model(masked_images, SNRdB=snr_db)

                loss = criterion(original_images, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_cost = train_loss / len(trainloader)
            tr_psnr = _psnr_from_mse(train_cost)

            # ============================ Validation ===========================
            model.eval()
            test_loss = 0.0

            with torch.no_grad():
                for masked_images, original_images in testloader:
                    original_images = original_images.to(device)
                    masked_images = masked_images.to(device)

                    outputs = model(masked_images, SNRdB=snr_db)
                    loss = criterion(original_images, outputs)
                    test_loss += loss.item()

            test_cost = test_loss / len(testloader)
            test_psnr = _psnr_from_mse(test_cost)

            # Early-stopping bookkeeping.
            if test_cost < min_test_cost:
                min_test_cost = test_cost
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == n_epochs_stop:
                print("Early stopping!")
                break

            # Wall-clock time of the whole epoch (training plus validation).
            training_time = time.time() - epoch_start

            print(
                "[{:>3}-Epoch({:>5}sec.)]  PSNR(Train / Val) : {:>6.4f} / {:>6.4f}        Loss(Train / Val) : {:>5.5f} / {:>5.5f}".format(
                    epoch + 1, round(training_time, 2), tr_psnr, test_psnr, train_cost, test_cost))

            # Checkpoint whenever the validation PSNR improves.
            if test_psnr > max_psnr:
                os.makedirs(save_folder, exist_ok=True)

                max_psnr = test_psnr
                save_path = os.path.join(
                    save_folder,
                    f"High_atten_selec(PS={patch_size}_DIM={latent_dim}_MR={mask_ratio}"
                    f"_SNR={snr_db}_PSNR={max_psnr}).pt")

                torch.save(model, save_path)



if __name__ == '__main__':

    # Sweep every (patch size, latent dim, mask ratio) combination. The masked
    # datasets depend only on patch size and mask ratio and are cached on disk,
    # so they are generated the first time a combination is encountered.
    for patch_size in params['PS']:
        for latent_dim in params['DIM']:
            for mask_ratio in params['MR']:

                suffix = f"(PS={patch_size}_MR={mask_ratio})"
                train_dir = "ProcessedTrain" + suffix
                test_dir = "ProcessedTest" + suffix

                # (directory, use CIFAR-10 training split?)
                for split_dir, is_train in ((train_dir, True), (test_dir, False)):
                    if os.path.exists(split_dir):
                        continue  # already preprocessed
                    to_tensor = transforms.Compose([transforms.ToTensor()])
                    cifar10 = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=to_tensor)
                    preprocess_and_save_dataset(cifar10, split_dir, patch_size=patch_size, mask_ratio=mask_ratio)

                train_set = PreprocessedCIFAR10Dataset(root_dir=train_dir)
                test_set = PreprocessedCIFAR10Dataset(root_dir=test_dir)

                train_batches = DataLoader(train_set, batch_size=params['BS'], shuffle=True)
                test_batches = DataLoader(test_set, batch_size=params['BS'], shuffle=True)

                train(latent_dim, patch_size, mask_ratio, train_batches, test_batches)
# Experiment configuration, read through the global ``params`` name.
# NOTE(review): the script does ``from params import *``, so this literal
# presumably lives in a separate ``params.py`` module — confirm.
params = {
    "BS": 64,                          # batch size for the data loaders
    "LR": 5e-4,                        # learning rate for the Adam optimizer
    "EP": 500,                         # maximum number of epochs per run
    "SNR": [0, 15, 30],                # channel SNRs in dB; one model is trained per value
    "DIM": [32, 128, 512],             # latent dimensions to sweep
    "MR": [0, 0.25, 0.5, 0.75, 1],     # mask ratios to sweep
    "PS": [2, 4, 8, 16],               # patch sizes to sweep
}

'Main' 카테고리의 다른 글

Low Attention Selection Performance (CR : 1/6, 1/24, 1/96) (PS : 2)  (1) 2024.03.15
Random Selection  (0) 2024.03.15
Low Attention Selection  (0) 2024.03.15
Patch selection with zero padding  (0) 2024.03.13
cifar10 patch correlation map  (0) 2024.03.11
Comments