UOMOP

Patch selection with zero padding 본문

Main

Patch selection with zero padding

Happy PinGu 2024. 3. 13. 15:25
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_correlation(patches):
    """Return the pairwise Pearson correlation matrix of ``patches``.

    Every entry along the first dimension is flattened into one vector,
    mean-centred, and compared with every other vector via a normalised dot
    product.  A vector with zero variance has its norm replaced by 1 before
    dividing, so it gets correlation 0 with everything (itself included)
    instead of producing NaN.

    Args:
        patches: tensor whose first dimension indexes the patches; it must be
            viewable as ``(N, -1)``.

    Returns:
        ``(N, N)`` tensor of correlations.
    """
    vectors = patches.view(patches.size(0), -1)
    centred = vectors - vectors.mean(dim=1, keepdim=True)
    lengths = centred.norm(dim=1, keepdim=True)
    # Guard against division by zero for constant patches.
    safe_lengths = torch.where(lengths == 0, torch.ones_like(lengths), lengths)
    gram = centred @ centred.T
    return gram / (safe_lengths @ safe_lengths.T)

def image_to_patches(image, patch_size):
    """Split a ``(C, H, W)`` image into non-overlapping square patches.

    Patches are returned in row-major order: patch index
    ``i * (W // patch_size) + j`` covers rows ``i*patch_size:(i+1)*patch_size``
    and columns ``j*patch_size:(j+1)*patch_size``.  Trailing rows/columns that
    do not fill a whole patch are dropped (``Tensor.unfold`` behaviour).

    Args:
        image: 3-D tensor laid out as ``(C, H, W)``.
        patch_size: side length of each square patch (also used as stride).

    Returns:
        Tensor of shape ``(n_patches, C, patch_size, patch_size)``.
    """
    channels, _, _ = image.shape
    # After the two unfolds the layout is (C, n_rows, n_cols, ps, ps).
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Move the channel axis next to the pixel axes so each patch's C*ps*ps
    # values are contiguous; viewing the (C, n_rows, n_cols, ...) layout
    # directly would splice together channels of different patches.
    patches = patches.permute(1, 2, 0, 3, 4).contiguous()
    return patches.view(-1, channels, patch_size, patch_size)

def select_patches(correlation_map, total_patches, patch_count, start_index=None):
    """Greedily pick distinct patch indices with low correlation to their predecessor.

    Starting from ``start_index``, the function repeatedly appends the
    not-yet-selected patch whose correlation with the most recently selected
    patch is lowest.

    Args:
        correlation_map: ``(N, N)`` correlation tensor indexed by patch.
        total_patches: number of patches ``N``.
        patch_count: how many indices to return.  It is capped at
            ``total_patches``, and a value ``<= 0`` yields an empty list.
        start_index: first patch to select.  It defaults to
            ``total_patches // 2``, the middle of the flat index order.

    Returns:
        List of ``min(patch_count, total_patches)`` distinct int indices.
    """
    patch_count = min(patch_count, total_patches)
    if patch_count <= 0:
        return []
    if start_index is None:
        # Middle of the flat index order.  On an even-sided square grid this is
        # the first patch of the middle row, not the geometric centre.
        start_index = total_patches // 2
    selected_indexes = [start_index]
    while len(selected_indexes) < patch_count:
        masked_correlation = correlation_map[selected_indexes[-1]].clone()
        # Use +inf rather than 1 (the largest valid correlation).  This keeps
        # already-selected patches from winning argmin even when every
        # remaining correlation is 1 or rounds slightly above it.
        masked_correlation[selected_indexes] = float('inf')
        selected_indexes.append(masked_correlation.argmin().item())
    return selected_indexes

def mask_unselected_patches(image, patches, selected_indexes, patch_size):
    """Return a copy of ``image`` with every pixel outside the selected patches set to 0.

    Args:
        image: tensor laid out as ``(C, H, W)``.  It is not modified.
        patches: unused; kept so existing callers keep working.
        selected_indexes: flat patch indices in row-major order, i.e.
            ``index = row * (W // patch_size) + col``.
        patch_size: side length of each square patch.

    Returns:
        New ``(C, H, W)`` tensor.  Pixels not covered by any selected patch
        are 0, including any trailing rows/columns that do not fill a whole
        patch.
    """
    _, H, W = image.shape
    # Patches per row comes from the width.  Using H would mis-place patches
    # on non-square images.
    n_cols = W // patch_size
    keep = torch.zeros((H, W), dtype=torch.bool, device=image.device)
    for index in selected_indexes:
        row = (index // n_cols) * patch_size
        col = (index % n_cols) * patch_size
        keep[row:row + patch_size, col:col + patch_size] = True
    masked_image = image.clone()
    masked_image[:, ~keep] = 0
    return masked_image

def plot_heatmap(correlation_map, patch_size):
    """Show ``correlation_map`` as a seaborn heatmap in a new 8x6 figure.

    The title records ``patch_size``; both axes are labelled as patch indices.
    Blocks on ``plt.show()`` when an interactive backend is active.
    """
    plt.figure(figsize=(8, 6))
    axes = sns.heatmap(correlation_map, cmap='viridis', annot=False)
    axes.set_title(f'Correlation Matrix (Patch Size {patch_size})')
    axes.set_xlabel('Patch Index')
    axes.set_ylabel('Patch Index')
    plt.show()

def main():
    """Demo: select low-correlation patches of one random CIFAR-10 image and display them.

    Downloads CIFAR-10 into ``./data`` if needed, then plots the patch
    correlation heatmap.  It keeps half of the patches via the greedy
    selection and shows the original image next to the masked one.
    """
    patch_size = 2
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2)
    batch, _labels = next(iter(loader))
    image = batch[0]  # a single random image

    patches = image_to_patches(image, patch_size)
    correlation_map = calculate_correlation(patches)
    print(correlation_map.shape)
    plot_heatmap(correlation_map, patch_size)

    # Fraction of patches that are KEPT; the rest get zeroed out below.
    n_patches = patches.shape[0]
    keep_ratio = 0.5
    kept_indexes = select_patches(correlation_map, n_patches, int(n_patches * keep_ratio))
    print(kept_indexes)

    masked_image = mask_unselected_patches(image, patches, kept_indexes, patch_size)

    def to_display(tensor):
        # Undo Normalize(0.5, 0.5) and move channels last for imshow.
        return tensor.permute(1, 2, 0) * 0.5 + 0.5

    panels = ((image, 'Original Image'), (masked_image, 'Masked Image'))
    plt.figure(figsize=(12, 6))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(to_display(picture))
        plt.title(caption)
    plt.show()

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

Other posts in the 'Main' category

High Attention Selection  (0) 2024.03.15
Low Attention Selection  (0) 2024.03.15
cifar10 patch correlation map  (0) 2024.03.11
DeepJSCC  (0) 2024.03.11
DDPM cifar10 Simple Unet  (1) 2024.02.14
Comments