UOMOP
CBS 본문
def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Score every non-overlapping patch of an image with a simple statistic.

    The image is tiled into ``patch_size x patch_size`` patches. Each patch is
    scored on an extended window that also includes ``how_many`` pixels of
    context on every side (clipped at the image border).

    Args:
        image: ``(H, W)`` numpy array or torch tensor. Leading axes (e.g. a
            channel axis) are allowed: the window is cut from the last two
            axes and the statistic is taken over everything inside it.
        patch_size: Side length of one patch in pixels.
        type: Statistic computed per window:
            ``'variance'`` / ``'color_diversity'`` -> standard deviation
            (note: std, not variance, despite the name);
            ``'mean_brightness'`` -> mean;
            ``'contrast'`` -> max minus min;
            ``'edge_density'`` -> sum of gradient magnitudes.
        how_many: Context margin (pixels) added around each patch.
        noise_scale: Std-dev of Gaussian noise added to every score. With 0
            the scores are unchanged, but one ``np.random.randn()`` draw per
            patch is still made.

    Returns:
        ``np.ndarray`` of shape ``(H // patch_size, W // patch_size)``.
        Trailing rows/columns that do not fill a whole patch are ignored.

    Raises:
        ValueError: if ``type`` is not one of the supported statistics.
    """

    def _edge_density(window):
        # Differentiate along the two spatial axes only, so a leading channel
        # axis does not add a third gradient component.
        dy, dx = np.gradient(window, axis=(-2, -1))
        return np.sum(np.sqrt(dx ** 2 + dy ** 2))

    statistics = {
        'variance': np.std,
        'mean_brightness': np.mean,
        'contrast': lambda window: window.max() - window.min(),
        'edge_density': _edge_density,
        'color_diversity': np.std,
    }
    try:
        statistic = statistics[type]
    except KeyError:
        raise ValueError(
            f"unknown importance type {type!r}; expected one of {sorted(statistics)}"
        ) from None

    if isinstance(image, torch.Tensor):
        # detach/cpu so tensors that track gradients or live on a GPU work too.
        image = image.detach().cpu().numpy()
    else:
        image = np.asarray(image)

    H, W = image.shape[-2:]
    rows, cols = H // patch_size, W // patch_size
    value_map = np.zeros((rows, cols))
    # Iterate over whole grid cells so a partial last patch cannot index past
    # the end of value_map.
    for r in range(rows):
        i = r * patch_size
        start_i = max(i - how_many, 0)
        end_i = min(i + patch_size + how_many, H)
        for c in range(cols):
            j = c * patch_size
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)
            window = image[..., start_i:end_i, start_j:end_j]
            noise = np.random.randn() * noise_scale
            value_map[r, c] = statistic(window) + noise
    return value_map
def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1, noise_scale=0):
    """Mask patches starting from a chessboard pattern, then adjust by importance.

    Every image is split into a ``(H // patch_size, W // patch_size)`` grid.
    Cells with even ``row + col`` start out masked. The target number of kept
    cells is ``int(num_cells * (1 - mask_ratio))``:

    * ``mask_ratio < 0.5``: masked cells are unmasked in order of decreasing
      importance until the target is reached;
    * ``mask_ratio > 0.5``: kept cells are masked in order of increasing
      importance until the target is reached;
    * ``mask_ratio == 0.5``: the chessboard is used as is.

    Importance comes from ``patch_importance`` on channel 0 of each image.

    Args:
        images: tensor of shape ``(B, C, H, W)``.
        patch_size: patch side length in pixels.
        mask_ratio: desired fraction of masked cells.
        importance_type, how_many, noise_scale: forwarded to
            ``patch_importance`` as ``type``, ``how_many``, ``noise_scale``.

    Returns:
        masked_images: copy of ``images`` with masked patches set to 0.
        unmasked_patches_image: the kept patches of all batch items placed side
            by side along the last axis, shape ``(C, patch_size, K * patch_size)``
            for K kept patches in total (K may be 0).
        patch_index: list of K row-major cell indices
            ``row * (W // patch_size) + col``, in the same order as the patches
            in ``unmasked_patches_image``. Indices of all batch items are
            concatenated with no separator.
    """
    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    masked_images = images.clone()
    kept_patches = []
    patch_index = []
    target_unmasked_patches = int(rows * cols * (1 - mask_ratio))

    # Initial pattern: cells where (row + col) is even are masked.
    chessboard = np.add.outer(np.arange(rows), np.arange(cols)) % 2 == 0

    for b in range(B):
        # Importance is computed on channel 0 only; detach/cpu so the
        # numpy conversion inside patch_importance also works for GPU tensors.
        importance = patch_importance(
            images[b, 0].detach().cpu(), patch_size, importance_type, how_many, noise_scale
        )
        mask = chessboard.copy()
        unmasked_count = int(np.sum(~mask))

        if mask_ratio < 0.5:
            # Reveal the most important masked cells first.
            candidates = np.argwhere(mask)
            order = np.argsort(importance[mask])[::-1]
            for r, c in candidates[order]:
                if unmasked_count >= target_unmasked_patches:
                    break
                mask[r, c] = False
                unmasked_count += 1
        elif mask_ratio > 0.5:
            # Hide the least important kept cells first.
            candidates = np.argwhere(~mask)
            order = np.argsort(importance[~mask])
            for r, c in candidates[order]:
                if unmasked_count <= target_unmasked_patches:
                    break
                mask[r, c] = True
                unmasked_count -= 1

        for r in range(rows):
            for c in range(cols):
                ys = slice(r * patch_size, (r + 1) * patch_size)
                xs = slice(c * patch_size, (c + 1) * patch_size)
                if mask[r, c]:
                    masked_images[b, :, ys, xs] = 0
                else:
                    kept_patches.append(images[b, :, ys, xs])
                    # Row stride is the number of grid *columns*, matching how
                    # reconstruct_masked_image decodes the index.
                    patch_index.append(cols * r + c)

    if kept_patches:
        unmasked_patches_image = torch.cat(kept_patches, dim=-1)
    else:
        # Nothing kept (e.g. mask_ratio == 1): torch.cat([]) would raise.
        unmasked_patches_image = images.new_zeros((C, patch_size, 0))
    return masked_images, unmasked_patches_image, patch_index
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Load the CIFAR-10 training set (downloaded to ./data if missing).
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# num_workers=0: only one batch is drawn below, so worker processes only add
# start-up cost, and in this unguarded top-level script they crash on
# platforms that spawn workers (Windows, macOS).
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0)

# Draw one random image.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Apply masking.
patch_size = 8
mask_ratio = 0.75
importance_type = 'variance'
how_many = 1
noise_scale = 0
masked_images, unmasked_patches_image, patch_index = chessboard_mask(images, patch_size, mask_ratio, importance_type, how_many, noise_scale)
print(masked_images.shape)
print(unmasked_patches_image.shape)
print(len(patch_index))
def reconstruct_masked_image(unmasked_patches, patch_index, image_shape, patch_size):
    """Paste kept patches back at their grid positions on a zero canvas.

    Args:
        unmasked_patches: tensor ``(C, patch_size, K * patch_size)`` holding K
            patches side by side along the last axis.
        patch_index: K row-major cell indices ``row * (W // patch_size) + col``;
            entry k belongs to the k-th patch in ``unmasked_patches``.
        image_shape: ``(B, C, H, W)`` shape of the returned tensor.
        patch_size: patch side length in pixels.

    Returns:
        Zero tensor of shape ``image_shape`` (default float dtype) on
        ``unmasked_patches.device`` with every patch written into batch item 0.
        ``patch_index`` carries no batch information, so only item 0 is filled.
    """
    B, C, H, W = image_shape
    cols = W // patch_size
    reconstructed_images = torch.zeros((B, C, H, W), device=unmasked_patches.device)
    for k, linear_idx in enumerate(patch_index):
        i, j = divmod(linear_idx, cols)
        patch = unmasked_patches[:, :, k * patch_size:(k + 1) * patch_size]
        reconstructed_images[0, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = patch
    return reconstructed_images
# Rebuild a full-size image containing only the kept patches.
image_shape = images.shape
reconstructed_image = reconstruct_masked_image(
    unmasked_patches=unmasked_patches_image,
    patch_index=patch_index,
    image_shape=image_shape,
    patch_size=patch_size,
)
def visualize_masked_image(original, masked, unmasked_patches_image, reconstructed):
    """Show original, masked, kept-patch strip and reconstruction side by side.

    Args:
        original: ``(C, H, W)`` numpy array of the input image.
        masked: ``(C, H, W)`` numpy array with masked patches zeroed.
        unmasked_patches_image: tensor of kept patches laid side by side,
            channels first.
        reconstructed: tensor ``(B, C, H, W)``; only item 0 is shown.
    """
    panels = [
        (original, 'Original Image'),
        (masked, 'Masked Image'),
        # This panel is the strip of kept patches, not a reconstruction.
        (unmasked_patches_image.detach().cpu().numpy(), 'Unmasked Patches'),
        (reconstructed[0].detach().cpu().numpy(), 'Reconstructed Image'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (chw, title) in zip(axes, panels):
        # matplotlib expects channels last: (C, H, W) -> (H, W, C).
        ax.imshow(np.transpose(chw, (1, 2, 0)))
        ax.set_title(title)
    plt.show()
# Visualize the first image of the batch.
visualize_masked_image(
    original=images[0].numpy(),
    masked=masked_images[0].numpy(),
    unmasked_patches_image=unmasked_patches_image,
    reconstructed=reconstructed_image,
)
'Main' 카테고리의 다른 글
Good (1) | 2024.07.26 |
---|---|
Position Estimator (0) | 2024.07.10 |
No Encoder Symbol Check (1) | 2024.07.05 |
No Masking Symbol Check (0) | 2024.07.05 |
Object/background focusing (0) | 2024.06.25 |
Comments