UOMOP

Masking strategy comparison ( STL10 ) 본문

카테고리 없음

Masking strategy comparison ( STL10 )

Happy PinGu 2024. 6. 13. 15:06
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Score every patch of an image with a local statistic.

    The last two axes of ``image`` are split into non-overlapping
    ``patch_size`` x ``patch_size`` patches. Each patch is scored on a window
    that extends the patch by ``how_many`` pixels on every side, clipped at
    the image border. Trailing rows/columns that do not fill a whole patch
    are ignored.

    Args:
        image: ``numpy.ndarray`` or ``torch.Tensor`` whose last two axes are
            (H, W).
        patch_size: Side length of a patch in pixels.
        type: Statistic computed on each window. One of ``'variance'``
            (standard deviation of the window), ``'mean_brightness'`` (mean),
            ``'contrast'`` (max - min), ``'edge_density'`` (sum of gradient
            magnitudes) or ``'color_diversity'`` (standard deviation, same
            computation as ``'variance'``).
        how_many: Number of context pixels added around each patch.
        noise_scale: Multiplier applied to a standard-normal sample that is
            added to every score.

    Returns:
        ``numpy.ndarray`` of shape ``(H // patch_size, W // patch_size)``
        holding one score per patch.

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """
    if isinstance(image, torch.Tensor):
        # detach/cpu so tensors on an accelerator or in a graph also convert.
        image = image.detach().cpu().numpy()

    supported = ('variance', 'mean_brightness', 'contrast', 'edge_density', 'color_diversity')
    if type not in supported:
        raise ValueError(f"Unknown importance type {type!r}; expected one of {supported}")

    H, W = image.shape[-2:]
    rows, cols = H // patch_size, W // patch_size
    value_map = np.zeros((rows, cols))

    # Iterate over whole patches only, so the indices always fit value_map
    # even when H or W is not a multiple of patch_size.
    for row in range(rows):
        i = row * patch_size
        start_i = max(i - how_many, 0)
        end_i = min(i + patch_size + how_many, H)

        for col in range(cols):
            j = col * patch_size
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            # Slice the last two axes, matching how H and W were read.
            extended_patch = image[..., start_i:end_i, start_j:end_j]

            if type in ('variance', 'color_diversity'):
                value = np.std(extended_patch)
            elif type == 'mean_brightness':
                value = np.mean(extended_patch)
            elif type == 'contrast':
                value = extended_patch.max() - extended_patch.min()
            else:  # 'edge_density'
                dy, dx = np.gradient(extended_patch, axis=(-2, -1))
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))

            noise = np.random.randn() * noise_scale
            value_map[row, col] = value + noise

    return value_map

def mask_based_on_importance(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=0, noise_scale=0):
    """Zero out the least important patches of every image in a batch.

    Importance is computed by ``patch_importance`` on channel 0 of each
    image. The ``mask_ratio`` fraction of patches with the lowest scores is
    set to zero across all channels.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a patch in pixels.
        mask_ratio: Fraction of patches to zero out.
        importance_type: Statistic name forwarded to ``patch_importance``.
        how_many: Context pixels forwarded to ``patch_importance``.
        noise_scale: Noise multiplier forwarded to ``patch_importance``.

    Returns:
        Tuple ``(masked_images, unmasked_counts)``: a clone of ``images``
        with the selected patches zeroed, and a list holding, per image, the
        number of patches that were left untouched.
    """
    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    num_patches = rows * cols

    # Compute the masked count directly from mask_ratio (no 1 - (1 - r)
    # round trip, which loses precision) and clamp it so out-of-range ratios
    # mask nothing / everything instead of yielding a negative slice bound.
    num_masked_patches = min(max(int(num_patches * mask_ratio), 0), num_patches)

    masked_images = images.clone()
    unmasked_counts = []

    for b in range(B):
        importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)

        # argsort is ascending, so the leading entries are the least important.
        lowest = np.argsort(importance_map.flatten())[:num_masked_patches]
        mask = np.zeros(num_patches, dtype=bool)
        mask[lowest] = True

        for i, j in np.argwhere(mask.reshape(rows, cols)).tolist():
            masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        # Report patches that remain visible, as the name promises.
        unmasked_counts.append(num_patches - num_masked_patches)

    return masked_images, unmasked_counts


def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=0, noise_scale=0):
    """Mask patches in a checkerboard pattern, adjusted to hit ``mask_ratio``.

    The base pattern masks every patch whose row + column index is even.
    When ``mask_ratio`` is below 0.5, the most important masked patches are
    revealed until the target number of visible patches is reached; when it
    is above 0.5, the least important visible patches are masked as well.
    Importance is computed by ``patch_importance`` on channel 0.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a patch in pixels.
        mask_ratio: Target fraction of patches to zero out.
        importance_type: Statistic name forwarded to ``patch_importance``.
        how_many: Context pixels forwarded to ``patch_importance``.
        noise_scale: Noise multiplier forwarded to ``patch_importance``.

    Returns:
        Tuple ``(masked_images, unmasked_counts)``: a clone of ``images``
        with masked patches zeroed in all channels, and a list of ints with
        the number of visible patches per image.
    """
    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    masked_images = images.clone()
    unmasked_counts = []

    target_unmasked_ratio = 1 - mask_ratio
    target_unmasked_patches = int(rows * cols * target_unmasked_ratio)

    # The base checkerboard is the same for every image, so build it once.
    row_idx, col_idx = np.indices((rows, cols))
    base_mask = (row_idx + col_idx) % 2 == 0

    for b in range(B):
        mask = base_mask.copy()
        unmasked_count = int(np.sum(~mask))

        # The importance map is only read when the checkerboard has to be
        # adjusted, so skip the per-patch scoring at exactly 0.5.
        if mask_ratio != 0.5:
            importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many, noise_scale)

            if mask_ratio < 0.5:
                # Reveal the most important masked patches first.
                candidates = np.argwhere(mask)
                order = np.argsort(importance_map[mask])[::-1]
                n_flip = max(target_unmasked_patches - unmasked_count, 0)
                chosen = candidates[order[:n_flip]]
                mask[chosen[:, 0], chosen[:, 1]] = False
                unmasked_count += len(chosen)
            else:
                # Hide the least important visible patches first.
                candidates = np.argwhere(~mask)
                order = np.argsort(importance_map[~mask])
                n_flip = max(unmasked_count - target_unmasked_patches, 0)
                chosen = candidates[order[:n_flip]]
                mask[chosen[:, 0], chosen[:, 1]] = True
                unmasked_count -= len(chosen)

        for i, j in np.argwhere(mask).tolist():
            masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        unmasked_counts.append(unmasked_count)

    return masked_images, unmasked_counts


def random_mask(images, patch_size=2, mask_ratio=0.5):
    """Zero out a uniformly random subset of patches in each image.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a patch in pixels.
        mask_ratio: Fraction of patches to zero out.

    Returns:
        Tuple ``(masked_images, unmasked_counts)``: a clone of ``images``
        with the drawn patches zeroed in all channels, and a list with the
        number of untouched patches per image.
    """
    batch, _, height, width = images.shape
    grid_shape = (height // patch_size, width // patch_size)
    total = grid_shape[0] * grid_shape[1]
    n_hide = int(total * mask_ratio)

    result = images.clone()
    for b in range(batch):
        # Draw distinct flat patch indices, then map them back to (row, col).
        picked = np.random.choice(total, n_hide, replace=False)
        grid_rows, grid_cols = np.unravel_index(picked, grid_shape)

        for r, c in zip(grid_rows.tolist(), grid_cols.tolist()):
            top = r * patch_size
            left = c * patch_size
            result[b, :, top:top + patch_size, left:left + patch_size] = 0

    # Every image hides the same number of patches.
    return result, [total - n_hide] * batch


def ppm(images, patch_size=2, mask_ratio=0.5):
    """Zero out random patches, favouring patches near the image centre.

    Patches are drawn without replacement with probability proportional to
    a Gaussian of their distance (in patch units) to the centre patch, so
    central patches are masked most often.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a patch in pixels.
        mask_ratio: Fraction of patches to zero out.

    Returns:
        Tuple ``(masked_images, unmasked_counts)``: a clone of ``images``
        with the drawn patches zeroed in all channels, and a list with the
        number of untouched patches per image.
    """
    batch, _, height, width = images.shape
    grid_shape = (height // patch_size, width // patch_size)
    total = grid_shape[0] * grid_shape[1]
    n_hide = int(total * mask_ratio)
    sigma = max(height, width) / 8

    # Patch-grid coordinates of the image centre.
    centre_row = (height // 2) // patch_size
    centre_col = (width // 2) // patch_size

    # The sampling weights depend only on the grid, not on the image.
    row_offsets = np.arange(grid_shape[0])[:, np.newaxis] - centre_row
    col_offsets = np.arange(grid_shape[1])[np.newaxis, :] - centre_col
    distances = np.sqrt(row_offsets ** 2 + col_offsets ** 2)

    weights = np.exp(-distances ** 2 / (2 * sigma ** 2))
    weights /= weights.sum()  # Normalize to sum to 1
    flat_weights = weights.flatten()

    result = images.clone()
    for b in range(batch):
        picked = np.random.choice(total, n_hide, replace=False, p=flat_weights)
        grid_rows, grid_cols = np.unravel_index(picked, grid_shape)

        for r, c in zip(grid_rows.tolist(), grid_cols.tolist()):
            top = r * patch_size
            left = c * patch_size
            result[b, :, top:top + patch_size, left:left + patch_size] = 0

    # Every image hides the same number of patches.
    return result, [total - n_hide] * batch


def cpm(images, patch_size=2, mask_ratio=0.5):
    """Zero out random patches, favouring patches far from the image centre.

    Patches are drawn without replacement with probability proportional to
    ``1 - exp(-d**2 / (2 * sigma**2))``, where ``d`` is the distance (in
    patch units) to the centre patch. The centre patch itself has weight 0
    and is never drawn. Because ``sigma`` is very large compared with any
    distance on the grid, the weights are approximately proportional to
    ``d**2``.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a patch in pixels.
        mask_ratio: Fraction of patches to zero out.

    Returns:
        Tuple ``(masked_images, unmasked_counts)``: a clone of ``images``
        with the drawn patches zeroed in all channels, and a list with the
        number of untouched patches per image.

    Raises:
        ValueError: Propagated from ``numpy.random.choice`` when more
            patches are requested than have a non-zero weight.
    """
    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    masked_images = images.clone()
    num_patches = rows * cols
    num_masked_patches = int(num_patches * mask_ratio)
    unmasked_counts = []

    center_x, center_y = H // 2, W // 2
    sigma = max(H, W) / 0.0005

    # The sampling weights depend only on the grid, so compute them once.
    x = np.arange(rows)
    y = np.arange(cols)
    xv, yv = np.meshgrid(x, y, indexing='ij')
    distances = np.sqrt((xv - center_x // patch_size) ** 2 + (yv - center_y // patch_size) ** 2)

    # -expm1(-t) equals 1 - exp(-t) but stays accurate for the tiny t that
    # the huge sigma produces, where 1 - exp(-t) would cancel towards 0.
    probabilities = -np.expm1(-distances ** 2 / (2 * sigma ** 2))
    weight_sum = probabilities.sum()
    if weight_sum > 0:
        probabilities /= weight_sum
    elif num_patches > 0:
        # Every patch sits on the centre (e.g. a 1x1 grid): fall back to a
        # uniform draw instead of dividing 0 by 0.
        probabilities[...] = 1.0 / num_patches
    probabilities_flat = probabilities.flatten()

    for b in range(B):
        mask_indices = np.random.choice(num_patches, num_masked_patches, replace=False, p=probabilities_flat)
        grid_rows, grid_cols = np.unravel_index(mask_indices, (rows, cols))

        for i, j in zip(grid_rows.tolist(), grid_cols.tolist()):
            masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        unmasked_counts.append(num_patches - num_masked_patches)

    return masked_images, unmasked_counts




# Load one random STL10 image and show every masking strategy at several
# mask ratios in a grid: row 0 is the original, one further row per ratio.
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.STL10(root='./data', split='unlabeled', download=True, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

images, _ = next(iter(loader))
mask_ratios = [0.25, 0.5, 0.75]

fig, axs = plt.subplots(len(mask_ratios) + 1, 6, figsize=(20, 10))

axs[0, 0].imshow(images[0].permute(1, 2, 0).numpy())
for ax in axs[0, 1:]:
    ax.axis('off')

# Column headers sit on the first row; titles stay visible on hidden axes.
column_titles = ['Original', 'Chessboard', 'Variance', 'Random', 'Periphery', 'Central']
for ax, col in zip(axs[0], column_titles):
    ax.set_title(col)

# Derive the row labels from the ratios so the two cannot drift apart.
row_labels = [f'MR = {ratio:.0%}' for ratio in mask_ratios]

patch_size = 3
how_many = 2

for i, ratio in enumerate(mask_ratios):
    masked_images_methods = [
        images,
        chessboard_mask(images, patch_size=patch_size, mask_ratio=ratio, importance_type='variance', how_many=how_many, noise_scale=0)[0],
        mask_based_on_importance(images, patch_size=patch_size, mask_ratio=ratio, importance_type='variance', how_many=how_many, noise_scale=0)[0],
        random_mask(images, patch_size=patch_size, mask_ratio=ratio)[0],
        ppm(images, patch_size=patch_size, mask_ratio=ratio)[0],
        cpm(images, patch_size=patch_size, mask_ratio=ratio)[0],
    ]

    for j, masked_images in enumerate(masked_images_methods):
        img = masked_images[0].permute(1, 2, 0).numpy()
        ax = axs[i + 1, j]
        ax.imshow(img)

        if j == 0:
            # axis('off') would also suppress the ylabel, so keep the axis
            # on for the label column and hide only ticks and frame.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_ylabel(row_labels[i], rotation=0, size='large', labelpad=60)
        else:
            ax.axis('off')

plt.tight_layout()
plt.show()

Comments