Masked, Reshaped, index Gen
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
def patch_importance(image, patch_size=2, type='variance', how_many=2):
    if isinstance(image, torch.Tensor):
        image = image.numpy()

    H, W = image.shape[-2:]
    value_map = np.zeros((H // patch_size, W // patch_size))

    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            # Extend the patch by `how_many` pixels on each side (clipped at the
            # image border) so the score also reflects the local neighborhood.
            start_i = max(i - how_many, 0)
            end_i = min(i + patch_size + how_many, H)
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)
            extended_patch = image[start_i:end_i, start_j:end_j]

            if type == 'variance':
                value = np.std(extended_patch)
            elif type == 'mean_brightness':
                value = np.mean(extended_patch)
            elif type == 'contrast':
                value = extended_patch.max() - extended_patch.min()
            elif type == 'edge_density':
                dy, dx = np.gradient(extended_patch)
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
            elif type == 'color_diversity':
                value = np.std(extended_patch)

            value_map[i // patch_size, j // patch_size] = value

    return value_map
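# Minimal usage sketch (illustrative): with patch_size=2 on a single-channel
# 32x32 input, patch_importance returns a 16x16 map with one score per 2x2 patch.
# The random tensor below is only a stand-in for a real image channel.
demo_image = torch.rand(32, 32)
demo_map = patch_importance(demo_image, patch_size=2, type='variance', how_many=2)
print(demo_map.shape)  # (16, 16)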
def chessboard_mask(images, patch_size=1, mask_ratio=0.5, importance_type='variance', how_many=1):
    # `mask_ratio` is expected to be an iterable of ratios, e.g. [0.234375, 0.4375, 0.609375].
    for MR in mask_ratio:
        B, C, H, W = images.shape
        masked_images = images.clone()
        unmasked_counts = []
        unmasked_patches = []
        patch_index = []
        target_unmasked_ratio = 1 - MR
        num_patches = (H // patch_size) * (W // patch_size)
        target_unmasked_patches = int(num_patches * target_unmasked_ratio)

        for b in range(B):
            patch_importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many)

            # Start from a chessboard pattern in which half of the patches are masked.
            mask = np.zeros((H // patch_size, W // patch_size), dtype=bool)
            for i in range(H // patch_size):
                for j in range(W // patch_size):
                    if (i + j) % 2 == 0:
                        mask[i, j] = True
            unmasked_count = np.sum(~mask)

            if MR < 0.5:
                # Unmask the most important masked patches until the target count is reached.
                masked_indices = np.argwhere(mask)
                importances = patch_importance_map[mask]
                sorted_indices = masked_indices[np.argsort(importances)[::-1]]
                for idx in sorted_indices:
                    if unmasked_count >= target_unmasked_patches:
                        break
                    mask[tuple(idx)] = False
                    unmasked_count += 1
            elif MR > 0.5:
                # Mask the least important unmasked patches until the target count is reached.
                unmasked_indices = np.argwhere(~mask)
                importances = patch_importance_map[~mask]
                sorted_indices = unmasked_indices[np.argsort(importances)]
                for idx in sorted_indices:
                    if unmasked_count <= target_unmasked_patches:
                        break
                    mask[tuple(idx)] = True
                    unmasked_count -= 1

            patches = []
            for i in range(H // patch_size):
                for j in range(W // patch_size):
                    if mask[i, j]:
                        masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0
                    else:
                        patch = images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size]
                        patches.append(patch)
                        patch_index.append((W // patch_size) * i + j)  # row-major index of the kept patch
            unmasked_patches.append(torch.cat(patches, dim=-1))
            unmasked_counts.append(unmasked_count)

        unmasked_patches_image = torch.cat(unmasked_patches, dim=-1)

        # Fold the strip of unmasked patches into a square image. The split lengths
        # are the square side lengths that match each mask ratio for 32x32 inputs.
        if MR == 0.234375:
            split_len = 28
            split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
            masked_23 = masked_images
            reshaped_23 = torch.cat(split_tensor, dim=1)
            index_23 = torch.tensor(patch_index)
        elif MR == 0.4375:
            split_len = 24
            split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
            masked_43 = masked_images
            reshaped_43 = torch.cat(split_tensor, dim=1)
            index_43 = torch.tensor(patch_index)
        elif MR == 0.609375:
            split_len = 20
            split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
            masked_60 = masked_images
            reshaped_60 = torch.cat(split_tensor, dim=1)
            index_60 = torch.tensor(patch_index)

    return masked_23, masked_43, masked_60, reshaped_23, reshaped_43, reshaped_60, index_23, index_43, index_60
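# The three mask ratios handled above are exactly those for which a 32x32 image
# with patch_size=1 (1024 patches) keeps a perfect-square number of pixels, so the
# strip of unmasked patches folds into a square tensor; the hard-coded split
# lengths 28, 24 and 20 are the side lengths of those squares. Illustrative check:
for mr, side in [(0.234375, 28), (0.4375, 24), (0.609375, 20)]:
    kept = int(1024 * (1 - mr))
    print(f"MR={mr}: {kept} unmasked patches = {side}x{side}")
    assert kept == side * side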
# Load the CIFAR-10 test set and take its first image.
transform = transforms.Compose([
    transforms.ToTensor()
])
cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

image, _ = cifar10[0]       # first image of the test set
image = image.unsqueeze(0)  # add batch dimension -> (1, 3, 32, 32)

masked_23, masked_43, masked_60, reshaped_23, reshaped_43, reshaped_60, index_23, index_43, index_60 = \
    chessboard_mask(image, mask_ratio=[0.234375, 0.4375, 0.609375], patch_size=1)
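# Quick shape check (illustrative): the masked outputs keep the original
# (1, 3, 32, 32) shape, each reshaped output is a square image built only from
# the unmasked pixels, and each index tensor lists their row-major positions.
print(masked_23.shape)    # torch.Size([1, 3, 32, 32])
print(reshaped_23.shape)  # torch.Size([3, 28, 28])
print(index_23.shape)     # torch.Size([784])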
# Plot the original, masked, and reshaped images
fig, axs = plt.subplots(2, 4, figsize=(16, 8))
axs[0, 0].imshow(image.squeeze().permute(1, 2, 0))
axs[0, 0].set_title("Original Image")
axs[0, 0].axis('off')
axs[1, 0].axis('off')  # bottom-left panel is intentionally left empty
axs[0, 1].imshow(masked_23.squeeze().permute(1, 2, 0))
axs[0, 1].set_title("Masked Image (MR=0.234375)")
axs[0, 1].axis('off')
axs[0, 2].imshow(masked_43.squeeze().permute(1, 2, 0))
axs[0, 2].set_title("Masked Image (MR=0.4375)")
axs[0, 2].axis('off')
axs[0, 3].imshow(masked_60.squeeze().permute(1, 2, 0))
axs[0, 3].set_title("Masked Image (MR=0.609375)")
axs[0, 3].axis('off')
axs[1, 1].imshow(reshaped_23.squeeze().permute(1, 2, 0))
axs[1, 1].set_title("Reshaped Image (MR=0.234375)")
axs[1, 1].axis('off')
axs[1, 2].imshow(reshaped_43.squeeze().permute(1, 2, 0))
axs[1, 2].set_title("Reshaped Image (MR=0.4375)")
axs[1, 2].axis('off')
axs[1, 3].imshow(reshaped_60.squeeze().permute(1, 2, 0))
axs[1, 3].set_title("Reshaped Image (MR=0.609375)")
axs[1, 3].axis('off')
plt.show()