UOMOP

Mean_far_selection 본문

Main

Mean_far_selection

Happy PinGu 2024. 4. 1. 18:34
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch

def select_patches_to_mask(images, patch_size, mask_ratio):
    """Zero out, per image, the patches whose pixels lie closest to the image's channel means.

    Each image is tiled into non-overlapping ``patch_size x patch_size`` patches
    in row-major order. A patch's score is ``sum(|pixel - channel_mean|)`` over
    all of its channels and pixels, where ``channel_mean`` is the mean of that
    channel over the whole image. The ``int(total_patches * mask_ratio)``
    lowest-scoring patches (ties broken by row-major position) are set to 0 in
    every channel, so the patches farthest from the mean are the ones kept.

    Args:
        images: 4-D tensor unpacked as ``(B, C, H, W)``.
        patch_size: side length of the square patches, in pixels.
        mask_ratio: fraction of patches to zero out. The resulting patch count
            is clamped to ``[0, total_patches]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``;
        ``images`` itself is not modified.

    Note:
        Only complete patches are scored and masked. If ``H`` or ``W`` is not
        a multiple of ``patch_size``, the leftover bottom rows / right columns
        are left untouched.
    """
    B, _, H, W = images.shape
    n_rows = H // patch_size
    n_cols = W // patch_size
    total_patches = n_rows * n_cols
    # Clamp so ratios outside [0, 1] neither skip masking oddly nor select duplicates.
    n_patches_to_mask = max(0, min(int(total_patches * mask_ratio), total_patches))

    if n_patches_to_mask == 0:
        return images.clone()

    # Allocate on the input's device so the final masked_fill also works for GPU tensors.
    mask = torch.zeros((B, H, W), dtype=torch.bool, device=images.device)

    for b in range(B):
        image = images[b]
        channel_means = image.mean(dim=(1, 2))[:, None, None]

        # Score every complete patch once, in row-major order. Scores do not
        # depend on which patches were already picked, so greedily taking the
        # current minimum k times equals taking the first k entries of a
        # stable ascending sort -- O(P log P) instead of O(k * P).
        scores = torch.stack([
            torch.abs(
                image[:, r * patch_size:(r + 1) * patch_size,
                      c * patch_size:(c + 1) * patch_size] - channel_means
            ).sum()
            for r in range(n_rows)
            for c in range(n_cols)
        ])
        chosen = torch.sort(scores, stable=True).indices[:n_patches_to_mask]

        for flat_idx in chosen.tolist():
            r, c = divmod(flat_idx, n_cols)
            mask[b, r * patch_size:(r + 1) * patch_size,
                 c * patch_size:(c + 1) * patch_size] = True

    # A (B, 1, H, W) mask broadcasts over the channel axis; masked_fill is out-of-place.
    return images.masked_fill(mask.unsqueeze(1), 0)


# Load training sample #11 from CIFAR-10 and add a leading batch dimension.
transform = transforms.Compose([transforms.ToTensor()])
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
image = cifar10[11][0].unsqueeze(0)

# Zero out 75% of the 2x2 patches.
patch_size = 2
mask_ratio = 0.75
masked_image = select_patches_to_mask(image, patch_size, mask_ratio)

# Draw the original and the masked image side by side, without axes.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = (('Original Image', image), ('Masked Image', masked_image))
for ax, (title, picture) in zip(axes, panels):
    ax.imshow(picture.squeeze().permute(1, 2, 0))
    ax.set_title(title)
    ax.axis('off')

plt.show()

Comments