Object/background focusing
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Compute an importance score for every patch_size x patch_size patch of a 2-D image.

    Each score is computed on the patch extended by `how_many` pixels on every side,
    and optional Gaussian noise (scaled by `noise_scale`) is added to the score.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    H, W = image.shape[-2:]
    value_map = np.zeros((H // patch_size, W // patch_size))

    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            # Extend the patch by `how_many` pixels on each side, clipped to the image border.
            start_i = max(i - how_many, 0)
            end_i = min(i + patch_size + how_many, H)
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)
            extended_patch = image[start_i:end_i, start_j:end_j]

            if type == 'variance':
                value = np.std(extended_patch)
            elif type == 'mean_brightness':
                value = np.mean(extended_patch)
            elif type == 'contrast':
                value = extended_patch.max() - extended_patch.min()
            elif type == 'edge_density':
                dy, dx = np.gradient(extended_patch)
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
            elif type == 'color_diversity':
                # Same as 'variance' here, since a single channel is passed in.
                value = np.std(extended_patch)
            else:
                raise ValueError(f"Unknown importance type: {type}")

            noise = np.random.randn() * noise_scale
            value_map[i // patch_size, j // patch_size] = value + noise

    return value_map
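
# Quick sanity check (added here, not part of the original post): for a 32x32 input
# and patch_size=2, patch_importance should return a 16x16 score map. The random
# array below is only an assumed stand-in for a real CIFAR-10 channel.
_demo_channel = np.random.rand(32, 32)
_demo_map = patch_importance(_demo_channel, patch_size=2, type='variance', how_many=1)
print(_demo_map.shape)  # expected: (16, 16)
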
def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=0, noise_scale=0, iterations=1):
    """Mask patches in a chessboard pattern, then adjust each image toward the target
    mask_ratio using its patch-importance map (important patches are kept unmasked)."""
    B, C, H, W = images.shape
    masked_images = images.clone()
    unmasked_counts = []

    target_unmasked_ratio = 1 - mask_ratio
    num_patches = (H // patch_size) * (W // patch_size)
    target_unmasked_patches = int(num_patches * target_unmasked_ratio)

    for b in range(B):
        patch_importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)

        # Start from a 50% chessboard pattern (True = masked).
        mask = np.zeros((H // patch_size, W // patch_size), dtype=bool)
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if (i + j) % 2 == 0:
                    mask[i, j] = True

        unmasked_count = np.sum(~mask)

        if mask_ratio < 0.5:
            # Unmask the most important masked patches until the target count is reached.
            masked_indices = np.argwhere(mask)
            importances = patch_importance_map[mask]
            sorted_indices = masked_indices[np.argsort(importances)[::-1]]
            for idx in sorted_indices:
                if unmasked_count >= target_unmasked_patches:
                    break
                mask[tuple(idx)] = False
                unmasked_count += 1
        elif mask_ratio > 0.5:
            # Mask the least important unmasked patches until the target count is reached.
            unmasked_indices = np.argwhere(~mask)
            importances = patch_importance_map[~mask]
            sorted_indices = unmasked_indices[np.argsort(importances)]
            for idx in sorted_indices:
                if unmasked_count <= target_unmasked_patches:
                    break
                mask[tuple(idx)] = True
                unmasked_count -= 1

        # Additional step: at exactly 50% masking, swap patches based on importance
        # (unmask the most important masked patch, mask the least important unmasked patch).
        if mask_ratio == 0.5:
            for _ in range(iterations):
                masked_indices = np.argwhere(mask)
                unmasked_indices = np.argwhere(~mask)
                masked_importances = patch_importance_map[mask]
                unmasked_importances = patch_importance_map[~mask]
                if len(masked_importances) > 0 and len(unmasked_importances) > 0:
                    most_important_masked_idx = masked_indices[np.argmax(masked_importances)]
                    least_important_unmasked_idx = unmasked_indices[np.argmin(unmasked_importances)]
                    mask[tuple(most_important_masked_idx)] = False
                    mask[tuple(least_important_unmasked_idx)] = True

        # Zero out the masked patches in every channel.
        for i in range(H // patch_size):
            for j in range(W // patch_size):
                if mask[i, j]:
                    masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        unmasked_counts.append(unmasked_count)

    return masked_images, unmasked_counts
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
images, _ = next(iter(loader))

def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# Display original image
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
imshow(images[0])
plt.title("Original Image")

# Display masked images
mask_ratios = [0.25, 0.5, 0.75]
for idx, mask_ratio in enumerate(mask_ratios, start=2):
    masked_images, _ = chessboard_mask(images, patch_size=2, mask_ratio=mask_ratio, importance_type='variance', how_many=1, noise_scale=0, iterations=16)
    plt.subplot(2, 2, idx)
    imshow(masked_images[0])
    plt.title(f"Masked Image {int(mask_ratio * 100)}%")

plt.tight_layout()
plt.show()
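
To confirm that the masking actually hits the requested ratio, the unmasked_counts returned by chessboard_mask can be compared against the total patch budget. The snippet below is a minimal check added here (not from the original post); it assumes the code above has already been run, so images and chessboard_mask are in scope.

# Sanity check: realized mask ratio vs. target for each setting.
num_patches = (images.shape[2] // 2) * (images.shape[3] // 2)  # patch_size = 2
for mask_ratio in [0.25, 0.5, 0.75]:
    _, unmasked_counts = chessboard_mask(images, patch_size=2, mask_ratio=mask_ratio,
                                         importance_type='variance', how_many=1,
                                         noise_scale=0, iterations=16)
    realized = 1 - unmasked_counts[0] / num_patches
    print(f"target {mask_ratio:.2f} -> realized {realized:.2f}")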