UOMOP

cifar10 patch correlation map 본문

Main

cifar10 patch correlation map

Happy PinGu 2024. 3. 11. 14:47
import torch
import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_correlation(patches):
    """Return the pairwise Pearson correlation matrix between patches.

    Every patch (all dimensions after the first) is flattened into a vector,
    mean-centred, and compared with every other patch via the normalised
    dot product of the centred vectors.

    Args:
        patches: Tensor whose leading dimension indexes patches, e.g.
            ``(N, C, p, p)``. Non-contiguous tensors are accepted, and
            integer tensors are promoted to float.

    Returns:
        ``(N, N)`` tensor whose entry ``[i, j]`` is the Pearson correlation
        between patch ``i`` and patch ``j``. A constant patch (zero variance)
        has an all-zero row and column, including its diagonal entry.
    """
    # reshape (unlike view) also works for non-contiguous inputs, such as the
    # raw output of Tensor.unfold.
    flattened = patches.reshape(patches.size(0), -1)
    if not (flattened.is_floating_point() or flattened.is_complex()):
        # mean() is undefined for integer dtypes (e.g. uint8 pixel data).
        flattened = flattened.float()

    centered = flattened - flattened.mean(dim=1, keepdim=True)
    norms = centered.norm(dim=1, keepdim=True)
    # A constant patch has a zero norm. Divide by 1 instead, so its
    # correlations come out as 0 rather than NaN.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)

    # Outer product of the norms gives the (N, N) denominator for each pair.
    return (centered @ centered.t()) / (norms @ norms.t())
def image_to_patches(image, patch_size):
    """Split a ``(C, H, W)`` image into non-overlapping square patches.

    Args:
        image: Tensor of shape ``(C, H, W)``; any channel count is accepted.
        patch_size: Side length of each square patch. When H or W is not a
            multiple of ``patch_size``, the trailing rows/columns that do not
            fill a whole patch are discarded (as ``Tensor.unfold`` does).

    Returns:
        Tensor of shape ``(nH * nW, C, patch_size, patch_size)`` with
        ``nH = H // patch_size`` and ``nW = W // patch_size``. Patches are
        ordered row-major over the patch grid: index ``i * nW + j`` is the
        patch whose top-left pixel is ``(i * patch_size, j * patch_size)``.
    """
    channels = image.shape[0]
    # (C, H, W) -> (C, nH, nW, patch_size, patch_size)
    grid = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Move the channel axis next to the pixel axes *before* flattening.
    # Reshaping the (C, nH, nW, ...) layout directly would pack several
    # spatial tiles of a single channel into the channel axis of one patch.
    grid = grid.permute(1, 2, 0, 3, 4)
    return grid.reshape(-1, channels, patch_size, patch_size)

def plot_heatmap(correlation_map, patch_size):
    """Draw ``correlation_map`` as a seaborn heatmap in a new 8x6-inch figure.

    Args:
        correlation_map: 2-D array-like accepted by ``seaborn.heatmap``.
        patch_size: Patch side length; it is shown in the figure title.

    The heatmap uses the 'viridis' colormap with no cell annotations. Both
    axes are labelled 'Patch Index', and ``plt.show()`` is called at the end.
    """
    heading = 'Correlation Matrix (Patch Size %s)' % (patch_size,)
    axis_label = 'Patch Index'

    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_map, annot=False, cmap='viridis')
    plt.title(heading)
    plt.xlabel(axis_label)
    plt.ylabel(axis_label)
    plt.show()

def main(patch_size=8):
    """Show the patch-correlation heatmap of one random CIFAR-10 training image.

    Downloads CIFAR-10 to ``./data`` if needed and draws one shuffled
    training image. It splits the image into ``patch_size`` x ``patch_size``
    patches, prints the shape of the patch correlation matrix and plots it
    as a heatmap.

    Args:
        patch_size: Side length of the square patches (default 8).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # The same mean/std on every channel is one affine map of all pixel
        # values. Pearson correlation is invariant to it, so it does not
        # change the heatmap.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    # Only one sample is needed, so load it in the main process. Worker
    # processes would add start-up cost with no benefit.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=1, shuffle=True, num_workers=0)

    images, _ = next(iter(trainloader))

    # Drop the batch dimension: batch_size=1, so take the only image.
    image = images[0]

    patches = image_to_patches(image, patch_size)
    correlation_map = calculate_correlation(patches)

    # Shape of the (num_patches, num_patches) correlation matrix.
    print(correlation_map.shape)

    plot_heatmap(correlation_map, patch_size)

if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()

'Main' 카테고리의 다른 글

High Attention Selection  (0) 2024.03.15
Low Attention Selection  (0) 2024.03.15
Patch selection with zero padding  (0) 2024.03.13
DeepJSCC  (0) 2024.03.11
DDPM cifar10 Simple Unet  (1) 2024.02.14
Comments