# UOMOP
#
# Chessboard_masking (post body)
#
# Main
#
# Chessboard_masking
#
# Happy PinGu, 2024-04-08 19:00
import cv2
import numpy as np
import torch
import torch.nn.functional as F

def patch_std(image, patch_size=2):
    """Return the standard deviation of every non-overlapping square patch of a 2-D image.

    Args:
        image: 2-D array of shape (H, W), e.g. a grayscale image.
        patch_size: Side length of the square patches, in pixels.

    Returns:
        float64 array of shape (ceil(H / patch_size), ceil(W / patch_size)) whose entry
        (i, j) is ``np.std`` of ``image[i*p:(i+1)*p, j*p:(j+1)*p]``. When H or W is not a
        multiple of ``patch_size``, the last row/column describes the smaller edge patch,
        measured over the pixels it actually contains.
    """
    H, W = image.shape
    # Ceil division so partial edge patches get a slot instead of indexing past the map.
    grid_h = (H + patch_size - 1) // patch_size
    grid_w = (W + patch_size - 1) // patch_size
    std_map = np.zeros((grid_h, grid_w))
    for gi in range(grid_h):
        top = gi * patch_size
        for gj in range(grid_w):
            left = gj * patch_size
            std_map[gi, gj] = np.std(image[top:top + patch_size, left:left + patch_size])
    return std_map

def _expand_patch_mask(patch_mask, patch_size, height, width):
    """Upsample a (grid_h, grid_w) patch mask to a contiguous (height, width) pixel mask.

    Patches cut off by the image border are cropped, so H and W need not be multiples of
    ``patch_size``.
    """
    pixels = np.repeat(np.repeat(patch_mask, patch_size, axis=0), patch_size, axis=1)
    return np.ascontiguousarray(pixels[:height, :width])


def _gray_complexity_map(image_chw, patch_size):
    """Score every patch of one (C, H, W) image tensor by the std of its grayscale pixels.

    The image is scaled by 255, clipped to [0, 255] (so out-of-range values saturate rather
    than wrap around in the uint8 cast) and converted with cv2.COLOR_RGB2GRAY. This means a
    3-channel RGB tensor with values in [0, 1] is expected.
    """
    rgb = image_chw.detach().permute(1, 2, 0).cpu().numpy() * 255
    rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    return patch_std(gray, patch_size)


def _adjust_mask_by_complexity(base_mask, complexity_map, mask_ratio):
    """Grow or shrink a boolean patch mask to exactly round(mask_ratio * n_patches) patches.

    Extra masked patches are the visible ones with the LOWEST complexity. Patches that get
    revealed are the masked ones with the HIGHEST complexity. Selecting a count through a
    stable argsort (ties broken in raster order) hits the target exactly. A quantile threshold
    compared with ``<=``/``>=`` overshoots when many patches share a score, e.g. flat
    zero-std regions.
    """
    mask = base_mask.reshape(-1).copy()
    scores = np.asarray(complexity_map, dtype=np.float64).reshape(-1)
    target = int(round(mask_ratio * mask.size))
    masked_count = int(mask.sum())
    if target > masked_count:
        visible = np.flatnonzero(~mask)
        order = np.argsort(scores[visible], kind="stable")
        mask[visible[order[:target - masked_count]]] = True
    elif target < masked_count:
        hidden = np.flatnonzero(mask)
        order = np.argsort(-scores[hidden], kind="stable")
        mask[hidden[order[:masked_count - target]]] = False
    return mask.reshape(base_mask.shape)


def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    """Zero out image patches in a chessboard pattern, optionally adjusted by patch complexity.

    Each image is divided into ``patch_size`` x ``patch_size`` patches. In the base pattern,
    patch (i, j) is masked when ``i + j`` is odd, which masks half of the patches.

    With ``complexity_based=True`` and ``mask_ratio != 0.5``, the pattern is adjusted per
    image so that ``round(mask_ratio * n_patches)`` patches are masked:

    - extra patches come from the flattest visible patches (lowest std of grayscale pixels);
    - revealed patches are the most textured masked ones.

    Args:
        images: Tensor of shape (B, C, H, W). The complexity path scales values by 255 and
            uses cv2.COLOR_RGB2GRAY, so it expects 3-channel RGB values in [0, 1].
        patch_size: Side length of a square patch in pixels. The pixel mask is cropped at the
            border, so H and W need not be multiples of it. The complexity path then requires
            ``patch_std`` to return one score per (possibly partial) patch.
        mask_ratio: Target fraction of masked patches, in [0, 1]. Honoured only when
            ``complexity_based`` is True; otherwise the plain 50% chessboard is used.
        complexity_based: Whether to adjust the chessboard using per-patch complexity.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``. Masked pixels are 0
        and every other pixel is copied unchanged (no lossy uint8 round trip).

    Raises:
        ValueError: If ``mask_ratio`` is outside [0, 1].
    """
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio!r}")

    B, _, H, W = images.shape
    grid_h = (H + patch_size - 1) // patch_size
    grid_w = (W + patch_size - 1) // patch_size
    rows = np.arange(grid_h)[:, None]
    cols = np.arange(grid_w)[None, :]
    # One parity for every ratio: adjusted masks only add to (r > 0.5) or remove from
    # (r < 0.5) these squares, so masks at different ratios are nested.
    base_mask = (rows + cols) % 2 == 1

    if complexity_based and mask_ratio != 0.5:
        pixel_masks = np.empty((B, H, W), dtype=bool)
        for b in range(B):
            complexity_map = _gray_complexity_map(images[b], patch_size)
            patch_mask = _adjust_mask_by_complexity(base_mask, complexity_map, mask_ratio)
            pixel_masks[b] = _expand_patch_mask(patch_mask, patch_size, H, W)
    else:
        # The pattern does not depend on the image: build it once and broadcast over the batch.
        pixel_masks = _expand_patch_mask(base_mask, patch_size, H, W)[None]

    mask = torch.from_numpy(pixel_masks).to(images.device).unsqueeze(1)  # (B or 1, 1, H, W)
    return images.masked_fill(mask, 0)


import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Build a shuffled CIFAR-10 training loader with batch_size=1 (stored under ./data,
# download=True) and pull one batch from it. `images` holds that batch's image tensor;
# the label is discarded.
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

images, _ = next(iter(loader))

# Helper function to display images
def imshow(img):
    """Draw a channels-first (C, H, W) image tensor on the current matplotlib axes.

    The tensor is detached and copied to the CPU before conversion. ``Tensor.numpy()`` fails
    for tensors that require grad or live on a GPU, so those can now be shown as well. The
    axis frame and ticks are hidden.
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# Show the original image followed by one masked version per ratio in `mask_ratios`.
mask_ratios = [0.25, 0.5, 0.75]
n_cols = 2
# One extra panel for the original image; grow the grid when more ratios are listed.
n_rows = (len(mask_ratios) + 1 + n_cols - 1) // n_cols

plt.figure(figsize=(12, 8))
plt.subplot(n_rows, n_cols, 1)
imshow(images[0])
plt.title("Original Image")

for idx, mask_ratio in enumerate(mask_ratios, start=2):
    # complexity_based=True is what lets ratios other than 0.5 differ from the plain chessboard.
    masked_images = mask_patches_chessboard(images, patch_size=2, mask_ratio=mask_ratio, complexity_based=True)
    plt.subplot(n_rows, n_cols, idx)
    imshow(masked_images[0])
    # Percent formatting rounds; int(mask_ratio * 100) would truncate float error (0.29 -> 28).
    plt.title(f"Masked Image {mask_ratio:.0%}")

plt.tight_layout()
plt.show()

# Comments