UOMOP
Patch selection with zero padding
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
def calculate_correlation(patches):
    """Return the pairwise Pearson correlation matrix between patches.

    Each patch (everything after the first dimension) is flattened into one
    vector, mean-centred, and correlated with every other patch.

    Args:
        patches: tensor whose first dimension indexes patches, e.g.
            ``(N, C, p, p)``. It does not need to be contiguous.

    Returns:
        ``(N, N)`` tensor of correlations in ``[-1, 1]`` (up to rounding).
        A constant patch has zero variance; it gets correlation 0 with every
        patch, itself included.
    """
    # reshape (not view) so non-contiguous inputs such as sliced/permuted
    # patch tensors are accepted.
    flattened = patches.reshape(patches.size(0), -1)
    centered = flattened - flattened.mean(dim=1, keepdim=True)
    norms = centered.norm(dim=1, keepdim=True)
    # Substitute 1 for zero norms so constant patches do not produce 0/0 = NaN.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    # Outer product of the norms gives the per-pair normaliser.
    return (centered @ centered.t()) / (norms @ norms.t())
def image_to_patches(image, patch_size):
    """Split a ``(C, H, W)`` image into non-overlapping square patches.

    Patches are returned in row-major grid order: patch ``i`` covers grid
    cell ``(i // (W // patch_size), i % (W // patch_size))``. Trailing rows or
    columns that do not fill a whole patch are dropped (``Tensor.unfold``
    behaviour).

    Args:
        image: tensor of shape ``(C, H, W)``; any number of channels.
        patch_size: edge length of each square patch (also used as stride).

    Returns:
        Tensor of shape ``(num_patches, C, patch_size, patch_size)``.
    """
    channels = image.shape[0]
    # Unfold height then width: shape becomes (C, nH, nW, p, p).
    grid = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Bring the channel axis next to the pixel axes before flattening the grid.
    # Viewing (C, nH, nW, p, p) directly as (-1, C, p, p) would splice pixels
    # from neighbouring grid cells into a single "patch".
    return grid.permute(1, 2, 0, 3, 4).reshape(-1, channels, patch_size, patch_size)
def select_patches(correlation_map, total_patches, patch_count):
    """Greedily pick patches that are least correlated with the previous pick.

    Starting from index ``total_patches // 2``, each step chooses the not-yet
    selected patch whose correlation with the most recently selected patch is
    lowest.

    Args:
        correlation_map: ``(total_patches, total_patches)`` correlation tensor.
        total_patches: number of patches (rows of ``correlation_map``).
        patch_count: how many patches to select; values above
            ``total_patches`` are clamped, values ``<= 0`` select nothing.

    Returns:
        List of distinct patch indices, in selection order.
    """
    patch_count = min(patch_count, total_patches)
    if patch_count <= 0:
        return []
    # Seed with the middle of the flattened index range. Assuming row-major
    # patch indices, this is the geometric centre only when the grid has an
    # odd number of rows; with an even row count it is the first patch of the
    # middle row.
    selected_indexes = [total_patches // 2]
    while len(selected_indexes) < patch_count:
        candidates = correlation_map[selected_indexes[-1]].clone()
        # +inf (rather than 1) guarantees an already-chosen patch can never win
        # argmin, even if all remaining correlations are 1 or round above it.
        candidates[selected_indexes] = float('inf')
        selected_indexes.append(candidates.argmin().item())
    return selected_indexes
def mask_unselected_patches(image, patches, selected_indexes, patch_size):
    """Return a copy of *image* with every pixel outside the selected patches set to 0.

    Patch indices are interpreted in row-major grid order, where the grid has
    ``W // patch_size`` columns. Pixels in leftover rows/columns that are not
    covered by any full patch are always zeroed.

    Args:
        image: tensor of shape ``(C, H, W)``; it is not modified.
        patches: unused; kept so existing callers keep working.
        selected_indexes: iterable of patch indices to keep.
        patch_size: edge length of each square patch.

    Returns:
        Tensor of the same shape as *image*.
    """
    _, height, width = image.shape
    # Number of grid columns. The width must be used here, not the height,
    # otherwise non-square images map indices to the wrong cells.
    patches_per_row = width // patch_size
    keep = torch.zeros((height, width), dtype=torch.bool, device=image.device)
    for index in selected_indexes:
        row = (index // patches_per_row) * patch_size
        col = (index % patches_per_row) * patch_size
        keep[row:row + patch_size, col:col + patch_size] = True
    masked_image = image.clone()
    masked_image[:, ~keep] = 0
    return masked_image
def plot_heatmap(correlation_map, patch_size):
    """Show the patch correlation matrix as a heatmap in a new figure.

    Args:
        correlation_map: square matrix accepted by ``seaborn.heatmap``
            (``main`` passes the tensor from ``calculate_correlation``).
        patch_size: patch edge length; only used in the plot title.
    """
    axis_label = 'Patch Index'
    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_map, cmap='viridis', annot=False)
    plt.title(f'Correlation Matrix (Patch Size {patch_size})')
    plt.xlabel(axis_label)
    plt.ylabel(axis_label)
    plt.show()
def main():
    """Run the patch-correlation demo on one random CIFAR-10 training image.

    The demo:
    - downloads CIFAR-10 into ``./data`` if needed;
    - prints and plots the patch correlation matrix;
    - greedily selects half of the patches;
    - shows the original image next to the version with unselected patches zeroed.
    """
    patch_size = 2
    mask_ratio = 0.5
    # Normalize((0.5,)*3, (0.5,)*3) maps [0, 1] pixels to [-1, 1]; the plots
    # below undo it with `* 0.5 + 0.5`.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=1, shuffle=True, num_workers=2)
    images, _ = next(iter(trainloader))
    image = images[0]  # batch_size=1, so this is the only image

    patches = image_to_patches(image, patch_size)
    correlation_map = calculate_correlation(patches)
    print(correlation_map.shape)
    plot_heatmap(correlation_map, patch_size)

    total_patches = patches.shape[0]
    keep_count = int(total_patches * mask_ratio)
    selected_patches_indexes = select_patches(correlation_map, total_patches, keep_count)
    print(selected_patches_indexes)

    masked_image = mask_unselected_patches(image, patches, selected_patches_indexes, patch_size)

    plt.figure(figsize=(12, 6))
    panels = (('Original Image', image), ('Masked Image', masked_image))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # CHW -> HWC and undo the normalisation. Zeroed (masked) pixels are 0 in
        # normalised space, so they display as mid-gray (0.5), not black.
        plt.imshow(picture.permute(1, 2, 0) * 0.5 + 0.5)
        plt.title(title)
    plt.show()


if __name__ == '__main__':
    main()
'Main' 카테고리의 다른 글
High Attention Selection (0) | 2024.03.15 |
---|---|
Low Attention Selection (0) | 2024.03.15 |
cifar10 patch correlation map (0) | 2024.03.11 |
DeepJSCC (0) | 2024.03.11 |
DDPM cifar10 Simple Unet (1) | 2024.02.14 |
Comments