Masking strategy comparison

Happy PinGu 2024. 4. 11. 17:45
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def patch_std(image, patch_size=2):
    # Calculate the standard deviation within each patch
    H, W = image.shape
    std_map = np.zeros((H // patch_size, W // patch_size))
    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            patch = image[i:i+patch_size, j:j+patch_size]
            std_map[i // patch_size, j // patch_size] = np.std(patch)
    return std_map
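
# Side note: the same per-patch standard deviation map can be computed without the
# double Python loop, assuming H and W are exact multiples of patch_size. This is a
# vectorized sketch of my own (the function name is mine) and is not used below.
def patch_std_vectorized(image, patch_size=2):
    H, W = image.shape
    blocks = image.reshape(H // patch_size, patch_size, W // patch_size, patch_size)
    return blocks.std(axis=(1, 3))  # std over each patch's pixels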

def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    # Start from a 50% chessboard pattern; when complexity_based is True, patches are added
    # to or removed from the mask according to their per-patch standard deviation so that
    # the final coverage matches mask_ratio.
    if mask_ratio != 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # Calculate complexity for each patch
            complexity_map = patch_std(gray, patch_size)

            # Initialize mask with chessboard pattern for complexity map dimensions
            complexity_height, complexity_width = complexity_map.shape
            mask = np.zeros((complexity_height, complexity_width), dtype=bool)
            mask[::2, 1::2] = 1
            mask[1::2, ::2] = 1

            if complexity_based:
                if mask_ratio > 0.5:
                    # Mask additional low-complexity patches among the currently unmasked ones
                    additional_masking_ratio = (mask_ratio - 0.5) / 0.5
                    complexity_threshold = np.quantile(complexity_map[~mask], 1 - additional_masking_ratio)
                    additional_mask = complexity_map <= complexity_threshold
                    mask[~mask] = additional_mask[~mask]
                else:
                    # Unmask high-complexity patches among the currently masked ones
                    unmasking_ratio = (0.5 - mask_ratio) / 0.5
                    complexity_threshold = np.quantile(complexity_map[mask], unmasking_ratio)
                    unmask = complexity_map >= complexity_threshold
                    mask[mask] = ~unmask[mask]

            # Apply mask to the original image based on complexity map
            for i in range(complexity_height):
                for j in range(complexity_width):
                    if mask[i, j]:
                        image[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = 0

            # Convert image back to PyTorch format
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    else:  # mask_ratio == 0.5: plain chessboard, no complexity-based adjustment needed
        B, C, H, W = images.shape
        masked_images = images.clone()

        # Create the chessboard pattern at patch resolution
        pattern = np.tile(np.array([[1, 0] * (W // (2 * patch_size)),
                                    [0, 1] * (W // (2 * patch_size))]),
                          (H // (2 * patch_size), 1))

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)

            # Expand the pattern to pixel resolution and apply chessboard masking
            mask = np.repeat(np.repeat(pattern, patch_size, axis=0), patch_size, axis=1)
            image[mask == 0] = 0

            # Convert back to PyTorch format
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)


    return masked_images
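
# Quick sanity check (a sketch of my own, not from the original post; the helper name is
# mine): with a dummy batch, the fraction of zeroed pixels should roughly track the
# requested mask_ratio when complexity_based=True.
def _check_chessboard_mask_ratio():
    dummy = torch.rand(4, 3, 32, 32)
    for r in [0.25, 0.5, 0.75]:
        out = mask_patches_chessboard(dummy, patch_size=2, mask_ratio=r, complexity_based=True)
        print(r, round((out == 0).float().mean().item(), 3))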
    
def random_patch_masking(images, patch_size, mask_ratio):

    B, C, H, W = images.shape
    n_patch_rows = H // patch_size
    n_patch_cols = W // patch_size
    n_patches_per_image = n_patch_rows * n_patch_cols
    n_patches_to_mask = int(n_patches_per_image * mask_ratio)

    masked_images = images.clone()

    for i in range(B):
        # Sample distinct patch indices so exactly n_patches_to_mask patches get masked
        # (torch.randint could repeat indices and mask fewer patches than requested)
        mask_indices = torch.randperm(n_patches_per_image)[:n_patches_to_mask]

        for idx in mask_indices:
            row = torch.div(idx, n_patch_cols, rounding_mode='floor') * patch_size
            col = (idx % n_patch_cols) * patch_size
            masked_images[i, :, row:row + patch_size, col:col + patch_size] = 0

    return masked_images

def select_patches_gaussian(images, patch_size, mask_ratio):

    select_ratio = 1 - mask_ratio
    B, C, H, W = images.shape
    # Image center coordinates
    image_center = np.array([H / 2, W / 2])
    # Grid of patch center coordinates
    grid_x, grid_y = np.meshgrid(np.arange(patch_size / 2, W, patch_size),
                                 np.arange(patch_size / 2, H, patch_size))
    patch_centers = np.stack([grid_y.flatten(), grid_x.flatten()], axis=1)
    # Distance from the image center and Gaussian selection probabilities
    distances = np.linalg.norm(patch_centers - image_center, axis=1)
    sigma = H / 4  # Spread of the Gaussian, chosen relative to the image size
    probabilities = np.exp(-distances ** 2 / (2 * sigma ** 2))
    probabilities /= probabilities.sum()  # Normalize to a probability distribution

    # Number of patches to keep (not mask)
    n_patches_to_select = int(np.ceil(len(patch_centers) * select_ratio))
    # Sample patches without replacement, biased toward the center
    selected_indices = np.random.choice(len(patch_centers), size=n_patches_to_select,
                                        replace=False, p=probabilities)
    # Build masked images: start from all zeros and copy the selected patches back
    masked_images = torch.zeros_like(images)
    for idx in selected_indices:
        row, col = map(int, patch_centers[idx])
        row_start = max(row - patch_size // 2, 0)
        col_start = max(col - patch_size // 2, 0)
        masked_images[:, :, row_start:row_start + patch_size, col_start:col_start + patch_size] = \
            images[:, :, row_start:row_start + patch_size, col_start:col_start + patch_size]

    return masked_images
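
# For intuition (illustrative numbers of my own, for a 32x32 image with patch_size=2,
# so sigma = H / 4 = 8): a patch right next to the center gets weight
# exp(-1.41**2 / 128) ~ 0.98, a patch 8 px away gets exp(-8**2 / 128) ~ 0.61, and a
# corner patch ~21 px away gets exp(-21.2**2 / 128) ~ 0.03, so the patches that are
# kept (left unmasked) are strongly concentrated around the image center.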
    
def batch_mask_center_block(images, patch_size, mask_ratio):
    B, C, H, W = images.shape
    masked_images = images.clone()

    # Calculate the number of pixels to mask
    total_pixels = H * W
    pixels_to_mask = int(total_pixels * mask_ratio)

    # Determine the size of the mask based on the mask ratio
    mask_height = mask_width = int(np.sqrt(pixels_to_mask))

    # If a square mask is too big or too small, adjust to make a rectangular mask
    if mask_height * mask_width != pixels_to_mask:
        # Calculate aspect ratio of the image
        aspect_ratio = W / H
        # Start with a square mask and adjust width and height
        if aspect_ratio >= 1:  # Width is greater than height
            mask_height = H
            mask_width = int(pixels_to_mask / mask_height)
        else:  # Height is greater than width
            mask_width = W
            mask_height = int(pixels_to_mask / mask_width)

        # Fine-tune the mask size to get as close as possible to the target pixel count
        while mask_height * mask_width < pixels_to_mask:
            if (mask_width + patch_size) * mask_height <= pixels_to_mask:
                mask_width += patch_size
            elif mask_width * (mask_height + patch_size) <= pixels_to_mask:
                mask_height += patch_size
            else:
                break  # no further growth possible without overshooting the target

        while mask_height * mask_width > pixels_to_mask:
            if (mask_width - patch_size) * mask_height >= pixels_to_mask:
                mask_width -= patch_size
            elif mask_width * (mask_height - patch_size) >= pixels_to_mask:
                mask_height -= patch_size
            else:
                break  # no further shrinkage possible without undershooting the target

    # Ensure mask dimensions are multiples of patch_size
    mask_width = mask_width // patch_size * patch_size
    mask_height = mask_height // patch_size * patch_size

    # Calculate the start position to center the mask
    start_i = (H - mask_height) // 2
    start_j = (W - mask_width) // 2

    # Apply the mask to the center of the image
    for b in range(B):
        masked_images[b, :, start_i:start_i + mask_height, start_j:start_j + mask_width] = 0

    return masked_images
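
# Worked example (illustrative, for a 32x32 CIFAR-10 image and mask_ratio=0.75):
# pixels_to_mask = 768 and sqrt(768) ~ 27.7, so the initial 27x27 square does not match
# exactly; since the aspect ratio is 1, the code sets mask_height = 32 and
# mask_width = int(768 / 32) = 24, giving a centered 32x24 block that starts at
# (start_i, start_j) = (0, 4).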
    
def auto_canny(image, sigma=0.33):
    image = cv2.GaussianBlur(image, (3, 3), 0)
    v = np.median(image)
    lower = int(max(0, (1.0 - sigma) * v))
    upper = int(min(255, (1.0 + sigma) * v))
    edged = cv2.Canny(image, lower, upper)
    return edged
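
# Example of the threshold rule (illustrative numbers): if the blurred grayscale image
# has a median pixel value of 120 and sigma = 0.33, then lower = int(0.67 * 120) = 80
# and upper = int(1.33 * 120) = 159, and cv2.Canny tracks edges with hysteresis between
# these two thresholds.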

def batch_mask_patches_auto_canny(images, patch_size=2, mask_ratio=0.5):
    B, C, H, W = images.shape
    masked_images = images.clone()

    for b in range(B):
        image = images[b].permute(1, 2, 0).cpu().numpy() * 255
        image = image.astype(np.uint8)

        # Convert RGB to BGR for OpenCV
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Apply Canny edge detection
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = auto_canny(gray)

        edge_strengths = []
        for i in range(0, edges.shape[0], patch_size):
            for j in range(0, edges.shape[1], patch_size):
                patch = edges[i:i+patch_size, j:j+patch_size]
                edge_strength = np.sum(patch)  # Calculate the sum of edge strengths in the patch
                edge_strengths.append((edge_strength, (i, j)))

        # Sort patches by their edge strength in ascending order
        edge_strengths.sort()
        num_patches_to_mask = int(len(edge_strengths) * mask_ratio)

        # Mask patches with the lowest edge strength
        for _, (i, j) in edge_strengths[:num_patches_to_mask]:
            image[i:i+patch_size, j:j+patch_size] = 0

        # Convert image back to the original format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype(np.float32) / 255.0
        masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    return masked_images
    
# Load one CIFAR-10 image and compare the masking strategies side by side
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

images, _ = next(iter(loader))  # Load a single batch (one image)
# Mask ratios to iterate over
mask_ratios = [0.25, 0.5, 0.75]

# Initialize matplotlib subplot
fig, axs = plt.subplots(len(mask_ratios) + 1, 6, figsize=(20, 10))  # Add one for the original image row

# Display the original image in the first row (its column title is set below)
axs[0, 0].imshow(images[0].permute(1, 2, 0).numpy())
axs[0, 0].set_xticks([])
axs[0, 0].set_yticks([])
for ax in axs[0, 1:]:
    ax.axis('off')  # Hide unused plots in the first row

# Titles for each column
column_titles = ['Original', 'Chessboard', 'Random', 'Gaussian', 'Center Block', 'Auto Canny']
for ax, col in zip(axs[0], column_titles):
    ax.set_title(col)

# Row labels for each masking ratio
row_labels = ['MR = 25%', 'MR = 50%', 'MR = 75%']

# Fill in the subplots with masked images for each ratio and method
for i, ratio in enumerate(mask_ratios):
    # Generate the masked images for each method at the current ratio
    masked_images_methods = [
        images,  # Original image, shown for reference in the first column
        mask_patches_chessboard(images, patch_size=2, mask_ratio=ratio, complexity_based=True),
        random_patch_masking(images, patch_size=2, mask_ratio=ratio),
        select_patches_gaussian(images, patch_size=2, mask_ratio=ratio),
        batch_mask_center_block(images, patch_size=2, mask_ratio=ratio),
        batch_mask_patches_auto_canny(images, patch_size=2, mask_ratio=ratio)
    ]

    # Display masked images
    for j, masked_images in enumerate(masked_images_methods):
        img = masked_images[0].permute(1, 2, 0).numpy()  # Convert to numpy array for matplotlib
        axs[i + 1, j].imshow(img)
        if j > 0:
            axs[i + 1, j].axis('off')  # Hide axes for cleaner visualization

    # Keep the first column's axis frame so the row label stays visible, but drop the ticks
    axs[i + 1, 0].set_xticks([])
    axs[i + 1, 0].set_yticks([])

    # Add row label
    axs[i + 1, 0].set_ylabel(row_labels[i], rotation=0, size='large', labelpad=60)

plt.tight_layout()
plt.show()
