UOMOP

Random Selection 본문

DE/Code

Random Selection

Happy PinGu 2024. 7. 30. 17:56
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

# Fetch CIFAR-10 into the working directory (downloading it if absent) and
# take the first sample's image as a batch of one.
cifar10 = CIFAR10(root='.', download=True, transform=ToTensor())
image, _ = cifar10[0]  # the class label is not needed
images = image[None]   # prepend a batch axis: (C, H, W) -> (1, C, H, W)

# Define the random_mask function
def random_mask(images, patch_size=2, mask_ratio=0.5, split_len=28):
    """Randomly mask square patches of a batch and pack the kept ones into a grid.

    Each image plane is divided into a (H // patch_size) x (W // patch_size)
    grid of patch_size x patch_size patches. For every image in the batch,
    ``num_patches - int(num_patches * (1 - mask_ratio))`` patches are picked
    uniformly at random without replacement (``np.random.choice``) and set
    to 0 in a copy of ``images``.

    Args:
        images: Tensor unpacked as (B, C, H, W).
        patch_size: Side length, in pixels, of each square patch.
        mask_ratio: Fraction of patches to mask; the kept count is
            truncated with ``int``.
        split_len: Width, in pixels, of each row of the packed output. The
            default 28 fits 196 kept 2x2 patches (a 14 x 14 grid).

    Returns:
        ``(masked_images, reshaped, patch_index)`` where
        - masked_images: copy of ``images`` with masked patches zeroed;
        - reshaped: all kept patches in row-major grid order, laid side by
          side, cut into ``split_len``-wide strips and stacked along dim 1,
          giving shape (C, rows, split_len). Consecutive batch images follow
          one another along dim 1;
        - patch_index: 1-D tensor of the row-major grid index of every kept
          patch, for all batch images concatenated.

    Raises:
        ValueError: if no patch is left unmasked, or the packed width spans
            more than one strip but is not a multiple of ``split_len``.
    """
    B, C, H, W = images.shape
    grid_h, grid_w = H // patch_size, W // patch_size
    num_patches = grid_h * grid_w
    num_unmasked = int(num_patches * (1 - mask_ratio))
    num_masked = num_patches - num_unmasked

    masked_images = images.clone()
    kept_patches = []  # kept patches of every image, batch-major then row-major
    patch_index = []

    for b in range(B):
        mask = np.zeros((grid_h, grid_w), dtype=bool)
        chosen = np.random.choice(num_patches, num_masked, replace=False)
        mask[np.unravel_index(chosen, mask.shape)] = True

        for i in range(grid_h):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(grid_w):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                if mask[i, j]:
                    masked_images[b, :, rows, cols] = 0
                else:
                    kept_patches.append(images[b, :, rows, cols])
                    # Row-major index in a grid that is grid_w patches wide,
                    # consistent with np.unravel_index(..., mask.shape) above.
                    patch_index.append(grid_w * i + j)

    if not kept_patches:
        raise ValueError("random_mask: no patches left unmasked; lower mask_ratio")

    # Lay every kept patch side by side: (C, patch_size, n_kept * patch_size).
    strip = torch.cat(kept_patches, dim=-1)
    width = strip.shape[-1]
    if width > split_len and width % split_len:
        raise ValueError(
            f"random_mask: packed width {width} is not a multiple of split_len={split_len}"
        )
    # Cut the strip into split_len-wide pieces and stack them vertically.
    reshaped = torch.cat(torch.split(strip, split_len, dim=-1), dim=1)

    return masked_images, reshaped, torch.tensor(patch_index)
# `images` is the single-image CIFAR-10 batch built at the top of this script.

# Apply random_mask. With patch_size=2 a 32x32 image has 256 patches;
# mask_ratio=0.2343 keeps int(256 * 0.7657) = 196 of them, which random_mask
# packs into a 28x28 grid of 2x2 patches.
masked_images, reshaped, patch_index = random_mask(images, patch_size=2, mask_ratio=0.2343)

# Plot the original image next to its masked version.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Original image (CHW -> HWC for matplotlib)
axs[0].imshow(images[0].permute(1, 2, 0))
axs[0].set_title('Original Image')
axs[0].axis('off')

# Masked image
axs[1].imshow(masked_images[0].permute(1, 2, 0))
axs[1].set_title('Masked Image')
axs[1].axis('off')

plt.show()

# Plot the packed grid of kept (unmasked) patches in its own figure.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

reshaped_image = reshaped.permute(1, 2, 0)  # (C, H', W') -> (H', W', C)

ax.imshow(reshaped_image)
ax.set_title('Reshaped Image')
ax.axis('off')

plt.show()

'DE > Code' 카테고리의 다른 글

Patch selection code (CBS)  (0) 2024.08.05
Image (variance , entropy , edge)  (0) 2024.08.05
CheckerBoard Selection  (0) 2024.07.30
Masking comparison (STL10)  (0) 2024.07.29
Masking comparison (Cifar10)  (0) 2024.07.29
Comments