Masking strategy comparison
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def patch_std(image, patch_size=2):
    # Calculate the standard deviation within each non-overlapping patch
    H, W = image.shape
    std_map = np.zeros((H // patch_size, W // patch_size))
    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            patch = image[i:i+patch_size, j:j+patch_size]
            std_map[i // patch_size, j // patch_size] = np.std(patch)
    return std_map
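# Optional sketch: the per-patch loop above is simple but slow for large batches.
# A vectorized variant, assuming H and W are divisible by patch_size (true for
# 32x32 CIFAR-10 images with patch_size=2), produces the same std map with one reshape.
# The name patch_std_vectorized is introduced only for this illustration.
def patch_std_vectorized(image, patch_size=2):
    H, W = image.shape
    patches = image.reshape(H // patch_size, patch_size, W // patch_size, patch_size)
    return patches.std(axis=(1, 3))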
def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    if mask_ratio != 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()
        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # Calculate complexity (patch-wise std of the grayscale image) for each patch
            complexity_map = patch_std(gray, patch_size)

            # Initialize mask with a chessboard pattern over the patch grid
            complexity_height, complexity_width = complexity_map.shape
            mask = np.zeros((complexity_height, complexity_width), dtype=bool)
            mask[::2, 1::2] = 1
            mask[1::2, ::2] = 1

            if complexity_based:
                if mask_ratio > 0.5:
                    # Additionally mask the lowest-complexity patches among the unmasked ones
                    additional_masking_ratio = (mask_ratio - 0.5) / 0.5
                    complexity_threshold = np.quantile(complexity_map[~mask], additional_masking_ratio)
                    additional_mask = complexity_map <= complexity_threshold
                    mask[~mask] = additional_mask[~mask]
                else:
                    # Unmask the highest-complexity patches among the masked ones
                    unmasking_ratio = (0.5 - mask_ratio) / 0.5
                    complexity_threshold = np.quantile(complexity_map[mask], 1 - unmasking_ratio)
                    unmask = complexity_map >= complexity_threshold
                    mask[mask] = ~unmask[mask]

            # Zero out the masked patches in the image
            for i in range(complexity_height):
                for j in range(complexity_width):
                    if mask[i, j]:
                        image[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = 0

            # Convert image back to PyTorch format
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)
    elif mask_ratio == 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        # Create the chessboard pattern over the patch grid
        pattern = np.tile(np.array([[1, 0] * (W // (2 * patch_size)),
                                    [0, 1] * (W // (2 * patch_size))]),
                          (H // (2 * patch_size), 1))

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)

            # Expand the patch-level pattern to pixel resolution and apply it
            mask = np.repeat(np.repeat(pattern, patch_size, axis=0), patch_size, axis=1)
            image[mask == 0] = 0  # Apply chessboard pattern masking

            # Convert back to PyTorch format
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)
    return masked_images
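# Quick sanity check (illustration only): the patch-level chessboard pattern used in the
# mask_ratio == 0.5 branch, printed for a small 8x8 image with patch_size=2.
# Exactly half of the patches are kept (1) and half are masked (0).
_H = _W = 8
_ps = 2
_pattern = np.tile(np.array([[1, 0] * (_W // (2 * _ps)), [0, 1] * (_W // (2 * _ps))]),
                   (_H // (2 * _ps), 1))
print(_pattern)
# [[1 0 1 0]
#  [0 1 0 1]
#  [1 0 1 0]
#  [0 1 0 1]]
print(_pattern.mean())  # 0.5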
def random_patch_masking(images, patch_size, mask_ratio):
    B, C, H, W = images.shape
    n_patches_horizontal = H // patch_size
    n_patches_vertical = W // patch_size
    n_patches_per_image = n_patches_horizontal * n_patches_vertical
    n_patches_to_mask = int(n_patches_per_image * mask_ratio)
    masked_images = images.clone()
    for i in range(B):
        # Sample distinct patch indices so the requested ratio is hit exactly
        mask_indices = torch.randperm(n_patches_per_image)[:n_patches_to_mask]
        for idx in mask_indices:
            row = torch.div(idx, n_patches_vertical, rounding_mode='floor') * patch_size
            col = (idx % n_patches_vertical) * patch_size
            masked_images[i, :, row:row + patch_size, col:col + patch_size] = 0
    return masked_images
def select_patches_gaussian(images, patch_size, mask_ratio):
    select_ratio = 1 - mask_ratio
    B, C, H, W = images.shape

    # Image center coordinates
    image_center = np.array([H / 2, W / 2])

    # Grid of patch center coordinates
    grid_x, grid_y = np.meshgrid(np.arange(patch_size / 2, W, patch_size),
                                 np.arange(patch_size / 2, H, patch_size))
    patch_centers = np.stack([grid_y.flatten(), grid_x.flatten()], axis=1)

    # Distance from the image center and Gaussian selection probabilities
    distances = np.linalg.norm(patch_centers - image_center, axis=1)
    sigma = H / 4  # spread of the Gaussian; patches near the center are preferred
    probabilities = np.exp(-distances ** 2 / (2 * sigma ** 2))
    probabilities /= probabilities.sum()  # normalize to a probability distribution

    # Number of patches to keep (the rest stay masked)
    n_patches_to_select = int(np.ceil(len(patch_centers) * select_ratio))

    # Sample patches without replacement according to the probabilities
    selected_indices = np.random.choice(len(patch_centers), size=n_patches_to_select,
                                        replace=False, p=probabilities)

    # Start from an all-zero (fully masked) image and copy back the selected patches
    masked_images = torch.zeros_like(images)
    for idx in selected_indices:
        row, col = map(int, patch_centers[idx])
        row_start = max(row - patch_size // 2, 0)
        col_start = max(col - patch_size // 2, 0)
        masked_images[:, :, row_start:row_start + patch_size, col_start:col_start + patch_size] = \
            images[:, :, row_start:row_start + patch_size, col_start:col_start + patch_size]
    return masked_images
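# Illustration of how strongly the Gaussian weighting favors the center (CIFAR-10 geometry:
# H = W = 32, patch_size = 2, sigma = H / 4, same formula as above). Before normalization,
# the centermost patch gets roughly 33x the weight of a corner patch.
_H = _W = 32
_ps = 2
_sigma = _H / 4
_center = np.array([_H / 2, _W / 2])
_center_patch = np.array([_H / 2 - _ps / 2, _W / 2 - _ps / 2])  # patch center closest to the middle
_corner_patch = np.array([_ps / 2, _ps / 2])                    # top-left patch center
_w_center = np.exp(-np.linalg.norm(_center_patch - _center) ** 2 / (2 * _sigma ** 2))
_w_corner = np.exp(-np.linalg.norm(_corner_patch - _center) ** 2 / (2 * _sigma ** 2))
print(_w_center / _w_corner)  # ~33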
def batch_mask_center_block(images, patch_size, mask_ratio):
    B, C, H, W = images.shape
    masked_images = images.clone()

    # Number of pixels to mask
    total_pixels = H * W
    pixels_to_mask = int(total_pixels * mask_ratio)

    # Start from a square mask whose area is close to the target
    mask_height = mask_width = int(np.sqrt(pixels_to_mask))

    # If a square mask does not hit the target exactly, fall back to a rectangular mask
    if mask_height * mask_width != pixels_to_mask:
        aspect_ratio = W / H
        if aspect_ratio >= 1:  # image is at least as wide as it is tall
            mask_height = H
            mask_width = int(pixels_to_mask / mask_height)
        else:                  # image is taller than it is wide
            mask_width = W
            mask_height = int(pixels_to_mask / mask_width)

        # Fine-tune the mask size toward the exact pixel count, in patch_size steps
        while mask_height * mask_width < pixels_to_mask:
            if (mask_width + patch_size) * mask_height <= pixels_to_mask:
                mask_width += patch_size
            elif mask_width * (mask_height + patch_size) <= pixels_to_mask:
                mask_height += patch_size
        while mask_height * mask_width > pixels_to_mask:
            if (mask_width - patch_size) * mask_height >= pixels_to_mask:
                mask_width -= patch_size
            elif mask_width * (mask_height - patch_size) >= pixels_to_mask:
                mask_height -= patch_size

        # Snap mask dimensions to multiples of patch_size
        mask_width = mask_width // patch_size * patch_size
        mask_height = mask_height // patch_size * patch_size

    # Position the mask at the center of the image
    start_i = (H - mask_height) // 2
    start_j = (W - mask_width) // 2

    # Zero out the central block in every image of the batch
    for b in range(B):
        masked_images[b, :, start_i:start_i + mask_height, start_j:start_j + mask_width] = 0
    return masked_images
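# Worked example (illustration only): for a 32x32 image at mask_ratio = 0.75, pixels_to_mask = 768,
# the initial square guess 27x27 = 729 does not match, so the rectangular fallback gives a
# 32 x 24 block (exactly 768 pixels) centered horizontally.
_dummy = torch.ones(1, 3, 32, 32)
_out = batch_mask_center_block(_dummy, patch_size=2, mask_ratio=0.75)
print((_out[0, 0] == 0).sum().item(), 32 * 24)  # 768 768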
def auto_canny(image, sigma=0.33):
    # Canny edge detection with thresholds derived from the median pixel intensity
    image = cv2.GaussianBlur(image, (3, 3), 0)
    v = np.median(image)
    lower = int(max(0, (1.0 - sigma) * v))
    upper = int(min(255, (1.0 + sigma) * v))
    edged = cv2.Canny(image, lower, upper)
    return edged
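# Illustration only: auto_canny derives its Canny thresholds from the median intensity,
# so a brighter image gets a higher (and possibly clipped) threshold band than a darker one.
# The two constant test images below exist only for this example.
_dark = np.full((8, 8), 50, dtype=np.uint8)
_bright = np.full((8, 8), 200, dtype=np.uint8)
for _name, _img in (("dark", _dark), ("bright", _bright)):
    _v = np.median(cv2.GaussianBlur(_img, (3, 3), 0))
    print(_name, int(max(0, (1.0 - 0.33) * _v)), int(min(255, (1.0 + 0.33) * _v)))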
def batch_mask_patches_auto_canny(images, patch_size=2, mask_ratio=0.5):
    B, C, H, W = images.shape
    masked_images = images.clone()
    for b in range(B):
        image = images[b].permute(1, 2, 0).cpu().numpy() * 255
        image = image.astype(np.uint8)

        # Convert RGB to BGR for OpenCV, then to grayscale for edge detection
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Apply Canny edge detection with median-based thresholds
        edges = auto_canny(gray)

        # The sum of edge responses within each patch serves as its edge strength
        edge_strengths = []
        for i in range(0, edges.shape[0], patch_size):
            for j in range(0, edges.shape[1], patch_size):
                patch = edges[i:i+patch_size, j:j+patch_size]
                edge_strength = np.sum(patch)
                edge_strengths.append((edge_strength, (i, j)))

        # Sort patches by edge strength in ascending order and mask the weakest ones
        edge_strengths.sort()
        num_patches_to_mask = int(len(edge_strengths) * mask_ratio)
        for _, (i, j) in edge_strengths[:num_patches_to_mask]:
            image[i:i+patch_size, j:j+patch_size] = 0

        # Convert the image back to the original tensor format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32) / 255.0
        masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)
    return masked_images
# Load a single CIFAR-10 image and visualize every masking strategy on it
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
images, _ = next(iter(loader))  # a single batch containing one image

# Mask ratios to compare
mask_ratios = [0.25, 0.5, 0.75]

# One row for the original image plus one row per mask ratio, one column per method
fig, axs = plt.subplots(len(mask_ratios) + 1, 6, figsize=(20, 10))

# Display the original image in the first row and hide the unused axes
axs[0, 0].imshow(images[0].permute(1, 2, 0).numpy())
for ax in axs[0, 1:]:
    ax.axis('off')

# Column titles, one per masking method
column_titles = ['Original', 'Chessboard', 'Random', 'Gaussian', 'Center Block', 'Auto Canny']
for ax, col in zip(axs[0], column_titles):
    ax.set_title(col)

# Row labels, one per masking ratio
row_labels = ['MR = 25%', 'MR = 50%', 'MR = 75%']

# Fill in the grid with the masked images for each ratio and method
for i, ratio in enumerate(mask_ratios):
    masked_images_methods = [
        images,  # original image repeated in the first column for reference
        mask_patches_chessboard(images, patch_size=2, mask_ratio=ratio, complexity_based=True),
        random_patch_masking(images, patch_size=2, mask_ratio=ratio),
        select_patches_gaussian(images, patch_size=2, mask_ratio=ratio),
        batch_mask_center_block(images, patch_size=2, mask_ratio=ratio),
        batch_mask_patches_auto_canny(images, patch_size=2, mask_ratio=ratio)
    ]
    for j, masked_images in enumerate(masked_images_methods):
        img = masked_images[0].permute(1, 2, 0).numpy()  # to numpy for matplotlib
        axs[i + 1, j].imshow(img)
        if j == 0:
            # Keep the first axis visible (ticks removed) so the row label is shown
            axs[i + 1, j].set_xticks([])
            axs[i + 1, j].set_yticks([])
            axs[i + 1, j].set_ylabel(row_labels[i], rotation=0, size='large', labelpad=60)
        else:
            axs[i + 1, j].axis('off')

plt.tight_layout()
plt.show()
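# Rough follow-up check (illustration only): fraction of fully-zero pixels each method produces,
# to complement the visual grid above. A pixel is counted as masked when all three channels are
# exactly zero, so naturally black pixels are counted as well.
methods = {
    'Chessboard': lambda x, r: mask_patches_chessboard(x, patch_size=2, mask_ratio=r, complexity_based=True),
    'Random': lambda x, r: random_patch_masking(x, patch_size=2, mask_ratio=r),
    'Gaussian': lambda x, r: select_patches_gaussian(x, patch_size=2, mask_ratio=r),
    'Center Block': lambda x, r: batch_mask_center_block(x, patch_size=2, mask_ratio=r),
    'Auto Canny': lambda x, r: batch_mask_patches_auto_canny(x, patch_size=2, mask_ratio=r),
}
for ratio in mask_ratios:
    for name, fn in methods.items():
        out = fn(images, ratio)
        masked_fraction = (out == 0).all(dim=1).float().mean().item()
        print(f'{name:12s} requested {ratio:.2f} -> zeroed pixel fraction {masked_fraction:.3f}')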