UOMOP
CIFAR-10 patch correlation map
import torch
import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
def calculate_correlation(patches):
    """Return the pairwise Pearson correlation matrix between patches.

    Each patch (one entry along dim 0) is flattened to a vector, centred by
    its own mean, and compared against every other patch.

    Args:
        patches: Tensor whose first dimension indexes patches; all remaining
            dimensions are flattened into one. It does not need to be
            contiguous.

    Returns:
        Tensor of shape (N, N) with N = patches.size(0). Entry (i, j) is the
        Pearson correlation between patch i and patch j. A constant
        (zero-variance) patch gets an all-zero row and column instead of NaN.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    flattened_patches = patches.reshape(patches.size(0), -1)
    mean_centered = flattened_patches - flattened_patches.mean(dim=1, keepdim=True)
    norm = mean_centered.norm(dim=1, keepdim=True)
    # A constant patch has zero norm. Dividing by 1 instead gives 0 correlation
    # (its centred vector is all zeros) rather than 0/0 = NaN. torch.where is
    # used instead of in-place assignment so autograd through `norm` still works.
    norm = torch.where(norm == 0, torch.ones_like(norm), norm)
    # (N, 1) @ (1, N) -> ||x_i|| * ||x_j|| for every pair (i, j).
    correlation_matrix = torch.mm(mean_centered, mean_centered.t()) / torch.mm(norm, norm.t())
    return correlation_matrix
def image_to_patches(image, patch_size):
    """Split a (C, H, W) image into non-overlapping square patches.

    Args:
        image: Tensor of shape (C, H, W).
        patch_size: Side length of each square patch. Trailing rows/columns
            that do not fill a whole patch are dropped (``Tensor.unfold``
            behaviour).

    Returns:
        Tensor of shape (N, C, patch_size, patch_size), where
        N = (H // patch_size) * (W // patch_size). Patches are ordered
        row-major: left to right within a row of patches, then top to bottom.
    """
    channels, _, _ = image.shape
    # Unfold H, then W: (C, H//p, W//p, p, p) with the last two axes being the
    # row and column inside each patch.
    patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Move the channel axis next to the pixel axes -> (H//p, W//p, C, p, p), so
    # every flattened patch keeps its own C channels. Reshaping the unpermuted
    # layout directly to (-1, C, p, p) would mix patches and channels.
    patches = patches.permute(1, 2, 0, 3, 4)
    return patches.reshape(-1, channels, patch_size, patch_size)
def plot_heatmap(correlation_map, patch_size):
    """Show ``correlation_map`` as a seaborn heatmap in a new 8x6 figure.

    Args:
        correlation_map: Matrix passed straight through to ``sns.heatmap``.
        patch_size: Patch side length, displayed in the figure title.
    """
    heading = f'Correlation Matrix (Patch Size {patch_size!s})'
    axis_label = 'Patch Index'

    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_map, annot=False, cmap='viridis')
    plt.title(heading)
    plt.xlabel(axis_label)
    plt.ylabel(axis_label)
    plt.show()
def main(patch_size=8):
    """Load one random CIFAR-10 training image and plot its patch correlation map.

    Args:
        patch_size: Side length of the square patches the image is split into
            (default 8).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # The same affine rescale (x - 0.5) / 0.5 is applied to every channel,
        # so Pearson correlations between patches are unaffected by it.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    # Only a single sample is drawn, so load it in the main process
    # (num_workers=0) instead of spawning worker processes for one batch.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0)
    images, _ = next(iter(trainloader))
    # Drop the batch dimension: (1, C, H, W) -> (C, H, W).
    image = images[0]
    patches_cifar = image_to_patches(image, patch_size)
    correlation_map = calculate_correlation(patches_cifar)
    # Shape is (N, N), where N is the number of patches.
    print(correlation_map.shape)
    plot_heatmap(correlation_map, patch_size)
# Entry point: run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
'Main' 카테고리의 다른 글
High Attention Selection (0) | 2024.03.15 |
---|---|
Low Attention Selection (0) | 2024.03.15 |
Patch selection with zero padding (0) | 2024.03.13 |
DeepJSCC (0) | 2024.03.11 |
DDPM cifar10 Simple Unet (1) | 2024.02.14 |
Comments