Position Estimator

Happy PinGu 2024. 7. 10. 12:52

import os
import time
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

from params import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    # Score each patch of a single-channel image by a local statistic computed
    # over the patch extended by `how_many` pixels of surrounding context.
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    H, W = image.shape[-2:]
    value_map = np.zeros((H // patch_size, W // patch_size))

    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            # Extend the patch window by `how_many` pixels on each side, clamped to the image.
            start_i = max(i - how_many, 0)
            end_i = min(i + patch_size + how_many, H)
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            extended_patch = image[start_i:end_i, start_j:end_j]

            if type == 'variance':
                value = np.std(extended_patch)
            elif type == 'mean_brightness':
                value = np.mean(extended_patch)
            elif type == 'contrast':
                value = extended_patch.max() - extended_patch.min()
            elif type == 'edge_density':
                dy, dx = np.gradient(extended_patch)
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
            elif type == 'color_diversity':
                value = np.std(extended_patch)
            else:
                raise ValueError(f"Unknown importance type: {type}")

            # Optional random jitter to break ties between equally important patches.
            noise = np.random.randn() * noise_scale
            value_map[i // patch_size, j // patch_size] = value + noise

    return value_map
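
As a quick sanity check, the importance map can be inspected on an arbitrary 32x32 array (a minimal sketch; the input values here are random, not a real image):

# Minimal sanity check with an arbitrary 32x32 single-channel image.
dummy = np.random.rand(32, 32).astype(np.float32)
imp = patch_importance(dummy, patch_size=2, type='variance', how_many=1)
print(imp.shape)  # (16, 16): one score per 2x2 patch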

def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=0, noise_scale=0):
    B, C, H, W = images.shape
    masked_images = images.clone()
    unmasked_counts = []
    unmasked_patches = []
    patch_index = []

    target_unmasked_ratio = 1 - mask_ratio
    num_patches = (H // patch_size) * (W // patch_size)
    target_unmasked_patches = int(num_patches * target_unmasked_ratio)

    for b in range(B):
        # Importance is scored on channel 0 only (e.g. R for RGB inputs).
        patch_importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)

        # Start from a chessboard pattern: cells with even (i + j) are masked.
        mask = np.zeros((H // patch_size, W // patch_size), dtype=bool)
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if (i + j) % 2 == 0:
                    mask[i, j] = True

        unmasked_count = np.sum(~mask)

        if mask_ratio < 0.5:
            # Unmask the most important masked patches until the target ratio is reached.
            masked_indices = np.argwhere(mask)
            importances = patch_importance_map[mask]
            sorted_indices = masked_indices[np.argsort(importances)[::-1]]

            for idx in sorted_indices:
                if unmasked_count >= target_unmasked_patches:
                    break
                mask[tuple(idx)] = False
                unmasked_count += 1

        elif mask_ratio > 0.5:
            # Mask the least important unmasked patches until the target ratio is reached.
            unmasked_indices = np.argwhere(~mask)
            importances = patch_importance_map[~mask]
            sorted_indices = unmasked_indices[np.argsort(importances)]

            for idx in sorted_indices:
                if unmasked_count <= target_unmasked_patches:
                    break
                mask[tuple(idx)] = True
                unmasked_count -= 1

        # Collect the surviving patches as flattened vectors.
        patches = []
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if not mask[i, j]:
                    patch = images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size]
                    patches.append(patch.view(-1))

        # Zero out the masked regions of the image.
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if mask[i, j]:
                    masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        unmasked_patches.append(torch.stack(patches))
        unmasked_counts.append(unmasked_count)

        # Label the extra unmasked cells: 1 for an even-parity (originally masked)
        # cell that ended up unmasked, 0 otherwise. This labeling is meaningful
        # only for mask_ratio < 0.5 (for larger ratios every label is 0), and
        # patch_index accumulates across the batch loop, so the preprocessing
        # below calls this function with B = 1.
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if (i + j) % 2 != 0:
                    patch_index.append(0)
                else:
                    patch_index.append(1 if not mask[i, j] else 0)

    return masked_images, unmasked_patches, patch_index
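
The masking arithmetic can be checked on a random batch (a minimal sketch; with patch_size=1 and mask_ratio=0.4 on 32x32 inputs, 614 of 1024 patches survive):

# Minimal sketch: mask one random CIFAR-sized image and inspect the outputs.
imgs = torch.rand(1, 3, 32, 32)
masked, patches, index = chessboard_mask(imgs, patch_size=1, mask_ratio=0.4)
print(masked.shape)      # torch.Size([1, 3, 32, 32])
print(patches[0].shape)  # torch.Size([614, 3]): 614 unmasked 1x1 RGB patches
print(sum(index))        # 102 extra unmasked cells beyond the 512-cell chessboard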


class PatchIndexPredictor(nn.Module):
    # Predicts, for each cell of the patch grid, the probability that it is an
    # extra unmasked patch. Input shape: [batch, seq_len, input_dim].
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(PatchIndexPredictor, self).__init__()
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=1)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=1)
        # 9600 = 128 channels x 75 positions, which holds only for seq_len = 614
        # (patch_size=1, mask_ratio=0.4 on 32x32 inputs); adjust if those change.
        self.fc1 = nn.Linear(9600, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [batch, seq_len, input_dim] -> [batch, input_dim, seq_len]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = x.view(x.size(0), -1)  # flatten to [batch, 128 * 75]
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.sig(x)  # per-cell probability in [0, 1]
        return x
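
The hard-coded 9600 in fc1 can be verified with a dummy forward pass (a sketch under the same patch_size=1 / mask_ratio=0.4 assumptions, so seq_len = 614):

# Shape check: 614 patches x 3 features -> conv stack -> [1, 128, 75] -> 9600 flat.
model = PatchIndexPredictor(input_dim=3, hidden_dim=5000, output_dim=1024)
out = model(torch.rand(1, 614, 3))
print(out.shape)  # torch.Size([1, 1024]): one probability per grid cell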




class PreprocessedCIFAR10Dataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.pt')]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        data = torch.load(file_path)
        unmasked_patches = data['unmasked_patches']
        patch_index = data['patch_index']
        if self.transform:
            unmasked_patches = self.transform(unmasked_patches)
        return unmasked_patches, patch_index


def preprocess_and_save_dataset(dataset, root_dir, patch_size, mask_ratio, importance_type, how_many, noise_scale):
    # Mask every image once offline and save (patches, labels) pairs, so the
    # expensive chessboard_mask call is not repeated every epoch.
    os.makedirs(root_dir, exist_ok=True)
    for i, (images, _) in tqdm(enumerate(dataset), total=len(dataset)):
        masked_images, unmasked_patches, patch_index = chessboard_mask(images.unsqueeze(0), patch_size, mask_ratio, importance_type, how_many, noise_scale)
        torch.save({
            'unmasked_patches': unmasked_patches[0],
            'patch_index': torch.tensor(patch_index, dtype=torch.int)
        }, os.path.join(root_dir, f'data_{i}.pt'))

def top_k_to_binary(tensor, k):
    # Keep the k highest-scoring positions per row as 1, everything else as 0.
    topk_indices = torch.topk(tensor, k, dim=1).indices
    binary_tensor = torch.zeros_like(tensor)
    binary_tensor.scatter_(1, topk_indices, 1)
    return binary_tensor
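
For example, on a toy score matrix:

# Toy example: the top-2 entries of each row become 1.
scores = torch.tensor([[0.1, 0.9, 0.4, 0.8]])
print(top_k_to_binary(scores, 2))  # tensor([[0., 1., 0., 1.]])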


def train_patch_index_predictor(input_channels, patch_size, hidden_dim, output_dim, trainloader, testloader, num_epochs, learning_rate):
    # Note: patch_size is currently unused inside this function.
    model = PatchIndexPredictor(input_channels, hidden_dim, output_dim).to(device)
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 102 = extra unmasked patches per image for patch_size=1, mask_ratio=0.4
    # (614 unmasked cells in total, minus the 512 base chessboard cells).
    top_k = 102

    for epoch in range(num_epochs):
        model.train()
        timetemp = time.time()

        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for unmasked_patches, patch_index in trainloader:
            unmasked_patches = unmasked_patches.to(device)
            patch_index = patch_index.float().to(device)

            optimizer.zero_grad()
            outputs = model(unmasked_patches)

            # Binarize predictions for the accuracy metric only; the loss uses raw probabilities.
            binary_outputs = top_k_to_binary(outputs, top_k)

            loss = criterion(outputs, patch_index)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (binary_outputs == patch_index).sum().item()
            train_total += patch_index.numel()

        train_cost = train_loss / len(trainloader)
        train_accuracy = 100 * train_correct / train_total

        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for unmasked_patches, patch_index in testloader:
                unmasked_patches = unmasked_patches.to(device)
                patch_index = patch_index.float().to(device)
                outputs = model(unmasked_patches)

                binary_outputs = top_k_to_binary(outputs, top_k)

                loss = criterion(outputs, patch_index)
                test_loss += loss.item()

                test_correct += (binary_outputs == patch_index).sum().item()
                test_total += patch_index.numel()

        test_cost = test_loss / len(testloader)
        test_accuracy = 100 * test_correct / test_total

        training_time = time.time() - timetemp

        print(
            "[{:>3}-Epoch({:>5}sec.)]  Accuracy(Train / Val) : {:>6.2f}% / {:>6.2f}%        Loss(Train / Val) : {:>6.4f} / {:>6.4f}".format(
                epoch + 1, round(training_time, 2), train_accuracy, test_accuracy, train_cost, test_cost))
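
The top-k value of 102 follows from the masking arithmetic (assuming patch_size=1, mask_ratio=0.4, and 32x32 images, as in the params below):

# Derivation of top_k = 102 under the assumptions above.
num_patches = (32 // 1) * (32 // 1)       # 1024 grid cells
target_unmasked = int(num_patches * 0.6)  # 614 cells survive masking
base_unmasked = num_patches // 2          # 512 odd-parity chessboard cells
print(target_unmasked - base_unmasked)    # 102 extra cells the model must find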

if __name__ == '__main__':
    for ps_i in range(len(params['PS'])):
        for dim_i in range(len(params['DIM'])):
            for mr_i in range(len(params['MR'])):
                for hm_i in range(len(params['HM'])):
                    Processed_train_path = "ProcessedTrain(PS=" + str(params['PS'][ps_i]) + "_MR=" + str(params['MR'][mr_i]) + "_IT=" + str(params['IT']) + "_HM=" + str(params['HM'][hm_i]) + ")"
                    Processed_test_path = "ProcessedTest(PS=" + str(params['PS'][ps_i]) + "_MR=" + str(params['MR'][mr_i]) + "_IT=" + str(params['IT']) + "_HM=" + str(params['HM'][hm_i]) + ")"

                    if not os.path.exists(Processed_train_path):
                        transform = transforms.Compose([transforms.ToTensor()])
                        train_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
                        preprocess_and_save_dataset(train_cifar10, Processed_train_path, patch_size=params['PS'][ps_i], mask_ratio=params['MR'][mr_i], importance_type=params['IT'], how_many=params['HM'][hm_i], noise_scale=params['NS'])

                    if not os.path.exists(Processed_test_path):
                        transform = transforms.Compose([transforms.ToTensor()])
                        test_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
                        preprocess_and_save_dataset(test_cifar10, Processed_test_path, patch_size=params['PS'][ps_i], mask_ratio=params['MR'][mr_i], importance_type=params['IT'], how_many=params['HM'][hm_i], noise_scale=params['NS'])

                    traindataset = PreprocessedCIFAR10Dataset(root_dir=Processed_train_path)
                    testdataset = PreprocessedCIFAR10Dataset(root_dir=Processed_test_path)

                    trainloader = DataLoader(traindataset, batch_size=params['BS'], shuffle=True, num_workers=4)
                    testloader = DataLoader(testdataset, batch_size=params['BS'], shuffle=True, num_workers=4)

                    input_channels = 3  # per-patch feature size: C * patch_size**2, i.e. 3 for 1x1 RGB patches
                    hidden_dim = 5000
                    output_dim = params['DIM'][dim_i]  # number of grid cells (1024 for 32x32 inputs with patch_size=1)

                    train_patch_index_predictor(input_channels, params['PS'][ps_i], hidden_dim, output_dim, trainloader, testloader, num_epochs=params['EP'], learning_rate=params['LR'])
# Contents of params.py (imported at the top via `from params import *`):
params = {
    'BS': 64,
    'LR': 0.00005,
    'EP': 5000,
    'SNR': [40],
    'DIM': [1024],
    'MR' : [0.4],
    'PS' : [1],
    'ES' : 20,
    'IT' : 'variance',
    'HM' : [1],
    'NS' : 0,
    'channel' : 'Rayleigh'
}

 
