UOMOP

CheckerBoard Selection 본문

DE/Code

CheckerBoard Selection

Happy PinGu 2024. 7. 30. 17:58
def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Score every ``patch_size`` x ``patch_size`` patch of an image.

    Each patch is scored on a window that extends the patch by ``how_many``
    pixels on every side, clipped at the image border.

    Args:
        image: Array or ``torch.Tensor`` whose last two axes are height and
            width. Any leading axes (e.g. channels) are kept inside every
            window and contribute to its statistic.
        patch_size: Side length of a patch in pixels.
        type: Name of the scoring rule:
            ``'variance'``         standard deviation of the window
            ``'mean_brightness'``  mean of the window
            ``'contrast'``         max minus min of the window
            ``'edge_density'``     sum of gradient magnitudes over the window
            ``'color_diversity'``  standard deviation (same as ``'variance'``)
        how_many: Number of context pixels added around each patch.
        noise_scale: Standard deviation of Gaussian noise added to each score.

    Returns:
        ``numpy.ndarray`` of shape ``(H // patch_size, W // patch_size)``.
        Trailing rows/columns that do not fill a whole patch are not scored.

    Raises:
        ValueError: If ``type`` is not one of the names listed above.
    """

    def _edge_density(window):
        # np.gradient needs at least two samples along each differentiated
        # axis; a single-pixel window has no edges.
        if min(window.shape[-2:]) < 2:
            return 0.0
        dy, dx = np.gradient(window, axis=(-2, -1))
        return np.sum(np.sqrt(dx ** 2 + dy ** 2))

    scorers = {
        'variance': np.std,
        'mean_brightness': np.mean,
        'contrast': lambda window: window.max() - window.min(),
        'edge_density': _edge_density,
        'color_diversity': np.std,
    }
    if type not in scorers:
        raise ValueError(
            f"unknown importance type {type!r}; expected one of {sorted(scorers)}"
        )
    score = scorers[type]

    if isinstance(image, torch.Tensor):
        # .numpy() rejects tensors that require grad or live off the CPU.
        image = image.detach().cpu().numpy()
    else:
        image = np.asarray(image)

    H, W = image.shape[-2:]
    rows, cols = H // patch_size, W // patch_size
    value_map = np.zeros((rows, cols))

    # Iterate over whole patches only so the indices always fit value_map.
    for r in range(rows):
        top = r * patch_size
        row_lo = max(top - how_many, 0)
        row_hi = min(top + patch_size + how_many, H)
        for c in range(cols):
            left = c * patch_size
            col_lo = max(left - how_many, 0)
            col_hi = min(left + patch_size + how_many, W)

            # Slice the last two axes: those are the ones H and W came from.
            window = image[..., row_lo:row_hi, col_lo:col_hi]

            value = score(window)
            noise = np.random.randn() * noise_scale
            value_map[r, c] = value + noise

    return value_map

def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1, noise_scale=0):
    """Mask patches in a checkerboard pattern adjusted to ``mask_ratio``.

    Every image starts from a checkerboard in which patches with an even
    ``row + column`` index are masked. The pattern is then adjusted using the
    scores returned by ``patch_importance`` for channel 0 of the image:

    * ``mask_ratio < 0.5``: masked patches are un-masked, highest score
      first, until ``int(num_patches * (1 - mask_ratio))`` patches are kept.
    * ``mask_ratio > 0.5``: kept patches are masked, lowest score first,
      until only that many patches remain.

    Args:
        images: Tensor of shape ``(B, C, H, W)``.
        patch_size: Side length of a patch in pixels.
        mask_ratio: Fraction of patches to mask.
        importance_type: Scoring rule name forwarded to ``patch_importance``.
        how_many: Context pixels around each patch, forwarded to
            ``patch_importance``.
        noise_scale: Score noise level, forwarded to ``patch_importance``.

    Returns:
        A tuple ``(masked_images, reshaped, patch_index)``:

        * ``masked_images``: copy of ``images`` with masked patches set to 0.
        * ``reshaped``: the kept patches, in row-major order, tiled into a
          compact ``(C, height, width)`` image. The tile grid is as close to
          square as the kept-patch count allows (e.g. 196 patches of size 2
          give 28 x 28). For ``B > 1`` the per-image tiles are stacked along
          the height axis.
        * ``patch_index``: 1-D tensor holding the row-major grid index of
          every kept patch, concatenated over the batch.
    """
    B, C, H, W = images.shape
    grid_rows, grid_cols = H // patch_size, W // patch_size
    num_patches = grid_rows * grid_cols

    target_unmasked_ratio = 1 - mask_ratio
    target_unmasked_patches = int(num_patches * target_unmasked_ratio)

    # True marks a masked patch. The base pattern is identical for every
    # image, so build it once and copy it per image.
    checkerboard = (np.add.outer(np.arange(grid_rows), np.arange(grid_cols)) % 2) == 0

    masked_images = images.clone()
    kept_patches = []
    patch_index = []

    for b in range(B):
        # Detach and move to CPU so the scorer can convert the tensor to numpy.
        importance_map = patch_importance(
            images[b, 0].detach().cpu(), patch_size, importance_type, how_many, noise_scale
        )

        mask = checkerboard.copy()
        unmasked_count = int(np.sum(~mask))

        if mask_ratio < 0.5:
            # Reveal the most important masked patches first.
            candidates = np.argwhere(mask)
            order = np.argsort(importance_map[mask])[::-1]
            needed = max(target_unmasked_patches - unmasked_count, 0)
            chosen = candidates[order][:needed]
            mask[chosen[:, 0], chosen[:, 1]] = False
        elif mask_ratio > 0.5:
            # Hide the least important visible patches first.
            candidates = np.argwhere(~mask)
            order = np.argsort(importance_map[~mask])
            surplus = max(unmasked_count - target_unmasked_patches, 0)
            chosen = candidates[order][:surplus]
            mask[chosen[:, 0], chosen[:, 1]] = True

        for i, j in np.argwhere(mask).tolist():
            masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

        # np.argwhere yields row-major order, matching a nested row/column scan.
        for i, j in np.argwhere(~mask).tolist():
            kept_patches.append(
                images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size]
            )
            # Row-major index: a row holds grid_cols patches.
            patch_index.append(grid_cols * i + j)

    if not kept_patches:
        # Empty batch or everything masked: nothing to tile.
        reshaped = images.new_zeros((C, 0, 0))
    else:
        # Every image keeps the same number of patches, so the batch total
        # divides evenly.
        per_image = len(kept_patches) // B
        # Most-square layout: the largest divisor not exceeding sqrt(per_image)
        # becomes the number of tile rows.
        tile_rows = max(
            d for d in range(1, per_image + 1)
            if d * d <= per_image and per_image % d == 0
        )
        split_len = (per_image // tile_rows) * patch_size

        strip = torch.cat(kept_patches, dim=-1)  # (C, patch_size, total * patch_size)
        reshaped = torch.cat(torch.split(strip, split_len, dim=2), dim=1)

    return masked_images, reshaped, torch.tensor(patch_index)
# Demo: run chessboard_mask on the CIFAR-10 sample at index 0 and plot the results.
cifar10 = CIFAR10(root='.', download=True, transform=ToTensor())
image, _ = cifar10[0]
images = image.unsqueeze(0)  # chessboard_mask expects a leading batch dimension

masked_images, reshaped, patch_index = chessboard_mask(images, patch_size=2, mask_ratio=0.2343)

# Figure 1: original and masked image side by side.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for panel, batch, title in zip(axs, (images, masked_images), ('Original Image', 'Masked Image')):
    # imshow wants channels last, so move C behind H and W.
    panel.imshow(batch[0].permute(1, 2, 0))
    panel.set_title(title)
    panel.axis('off')

plt.show()

# Figure 2: the kept patches tiled into one compact image.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

print(1)
print(reshaped.shape)

reshaped_image = reshaped.permute(1, 2, 0)

ax.imshow(reshaped_image)
ax.set_title('Reshaped Image')
ax.axis('off')

plt.show()

'DE > Code' 카테고리의 다른 글

Patch selection code (CBS)  (0) 2024.08.05
Image (variance , entropy , edge)  (0) 2024.08.05
Random Selection  (0) 2024.07.30
Masking comparison (STL10)  (0) 2024.07.29
Masking comparison (Cifar10)  (0) 2024.07.29
Comments