Masked, Reshaped, index Gen

Happy PinGu 2024. 8. 10. 21:16
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

def patch_importance(image, patch_size=2, type='variance', how_many=2):
    """Compute an importance score for every patch of a single-channel image.

    Each patch is scored over an extended window that includes `how_many`
    extra pixels of context on every side.
    """
    if isinstance(image, torch.Tensor):
        image = image.numpy()

    H, W = image.shape[-2:]

    value_map = np.zeros((H // patch_size, W // patch_size))

    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            # Extend the patch by `how_many` pixels on each side, clipped to the image.
            start_i = max(i - how_many, 0)
            end_i = min(i + patch_size + how_many, H)
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            extended_patch = image[start_i:end_i, start_j:end_j]

            if type == 'variance':
                value = np.std(extended_patch)  # np.std, i.e. standard deviation of the window
            elif type == 'mean_brightness':
                value = np.mean(extended_patch)
            elif type == 'contrast':
                value = extended_patch.max() - extended_patch.min()
            elif type == 'edge_density':
                dy, dx = np.gradient(extended_patch)
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
            elif type == 'color_diversity':
                # Identical to 'variance' here, since only a single channel is passed in.
                value = np.std(extended_patch)
            else:
                raise ValueError(f"Unknown importance type: {type}")

            value_map[i // patch_size, j // patch_size] = value

    return value_map
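
# Quick illustration (my own sketch, not part of the original post): patch_importance
# returns one score per patch, so an 8x8 input with 2x2 patches yields a 4x4 score map.
# `_demo` is just a throwaway random array for demonstration.
_demo = np.random.rand(8, 8).astype(np.float32)
_vm = patch_importance(_demo, patch_size=2, type='variance', how_many=1)
print(_vm.shape)  # (4, 4)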

def chessboard_mask(images, patch_size=1, mask_ratio=(0.234375, 0.4375, 0.609375), importance_type='variance', how_many=1):
    """Generate masked images, packed ("reshaped") unmasked patches, and patch indices.

    `mask_ratio` is a sequence of mask ratios; each ratio is processed in turn
    (only the three default ratios are handled by the reshaping step below).
    """
    for mr_i in range(len(mask_ratio)):
        MR = mask_ratio[mr_i]
        B, C, H, W = images.shape
        masked_images = images.clone()
        unmasked_counts = []
        unmasked_patches = []
        patch_index = []

        # Number of patches that should remain unmasked for this mask ratio.
        target_unmasked_ratio = 1 - MR
        num_patches = (H // patch_size) * (W // patch_size)
        target_unmasked_patches = int(num_patches * target_unmasked_ratio)

        for b in range(B):

            # Importance score of every patch, computed on the first channel only.
            patch_importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many)

            # Start from a chessboard pattern: exactly half of the patches are masked (True).
            mask = np.zeros((H // patch_size, W // patch_size), dtype=bool)
            for i in range(H // patch_size):
                for j in range(W // patch_size):
                    if (i + j) % 2 == 0:
                        mask[i, j] = True

            unmasked_count = np.sum(~mask)

            if MR < 0.5:
                # Fewer than half of the patches should be masked: unmask the most
                # important masked patches until the target unmasked count is reached.
                masked_indices = np.argwhere(mask)
                importances = patch_importance_map[mask]
                sorted_indices = masked_indices[np.argsort(importances)[::-1]]

                for idx in sorted_indices:
                    if unmasked_count >= target_unmasked_patches:
                        break
                    mask[tuple(idx)] = False
                    unmasked_count += 1

            elif MR > 0.5:
                # More than half of the patches should be masked: mask the least
                # important unmasked patches until the target unmasked count is reached.
                unmasked_indices = np.argwhere(~mask)
                importances = patch_importance_map[~mask]
                sorted_indices = unmasked_indices[np.argsort(importances)]

                for idx in sorted_indices:
                    if unmasked_count <= target_unmasked_patches:
                        break
                    mask[tuple(idx)] = True
                    unmasked_count -= 1

            # Zero out the masked patches; collect each unmasked patch together with its
            # flat index over the patch grid (idx = (H // patch_size) * i + j).
            patches = []
            for i in range(H // patch_size):
                for j in range(W // patch_size):
                    if mask[i, j]:
                        masked_images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0
                    else:
                        patch = images[b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size]
                        patches.append(patch)
                        patch_index.append((H // patch_size) * i + j)

            # Concatenate the surviving patches into one horizontal strip.
            unmasked_patches.append(torch.cat(patches, dim=-1))
            unmasked_counts.append(unmasked_count)
            unmasked_patches_image = torch.cat(unmasked_patches, dim=-1)

            # Each supported mask ratio leaves a perfect-square number of unmasked
            # patches (for a 32x32 input with patch_size=1: 784 = 28^2, 576 = 24^2,
            # 400 = 20^2), so the strip can be folded into a square image.
            if MR == 0.234375:
                split_len = 28
                split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
                masked_23 = masked_images
                reshaped_23 = torch.cat(split_tensor, dim=1)
                index_23 = torch.tensor(patch_index)

            elif MR == 0.4375:
                split_len = 24
                split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
                masked_43 = masked_images
                reshaped_43 = torch.cat(split_tensor, dim=1)
                index_43 = torch.tensor(patch_index)

            elif MR == 0.609375:
                split_len = 20
                split_tensor = torch.split(unmasked_patches_image, split_len, dim=2)
                masked_60 = masked_images
                reshaped_60 = torch.cat(split_tensor, dim=1)
                index_60 = torch.tensor(patch_index)

    return masked_23, masked_43, masked_60, reshaped_23, reshaped_43, reshaped_60, index_23, index_43, index_60
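
# Hypothetical helper (my own sketch, not from the original post): the flat indices in
# index_* run row-major over the 32x32 patch grid (idx = 32*i + j for patch_size=1), so
# they can scatter the packed pixels of a reshaped image back to their original
# positions. With the masked positions left at zero, the result is expected to match
# the corresponding masked image, e.g.
#   torch.allclose(unpack_reshaped(reshaped_23, index_23), masked_23.squeeze(0))
def unpack_reshaped(reshaped, index, grid=32):
    C = reshaped.shape[0]
    flat = reshaped.reshape(C, -1)                      # pixels in collection order
    out = torch.zeros(C, grid, grid, dtype=reshaped.dtype)
    out[:, index // grid, index % grid] = flat          # scatter back to grid positions
    return out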

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
image, _ = cifar10[0]  # Get the first image from the test set
image = image.unsqueeze(0)  # Add batch dimension
# Generate masked images, reshaped (packed) images, and patch indices for the three mask ratios
masked_23, masked_43, masked_60, reshaped_23, reshaped_43, reshaped_60, index_23, index_43, index_60 = chessboard_mask(image, mask_ratio=[0.234375, 0.4375, 0.609375], patch_size=1)

# Plot the original, masked, and reshaped images
fig, axs = plt.subplots(2, 4, figsize=(16, 8))

axs[0, 0].imshow(image.squeeze().permute(1, 2, 0))
axs[0, 0].set_title("Original Image")
axs[0, 0].axis('off')
axs[1, 0].axis('off')  # hide the unused bottom-left panel

axs[0, 1].imshow(masked_23.squeeze().permute(1, 2, 0))
axs[0, 1].set_title("Masked Image (MR=0.234375)")
axs[0, 1].axis('off')

axs[0, 2].imshow(masked_43.squeeze().permute(1, 2, 0))
axs[0, 2].set_title("Masked Image (MR=0.4375)")
axs[0, 2].axis('off')

axs[0, 3].imshow(masked_60.squeeze().permute(1, 2, 0))
axs[0, 3].set_title("Masked Image (MR=0.609375)")
axs[0, 3].axis('off')

axs[1, 1].imshow(reshaped_23.squeeze().permute(1, 2, 0))
axs[1, 1].set_title("Reshaped Image (MR=0.234375)")
axs[1, 1].axis('off')

axs[1, 2].imshow(reshaped_43.squeeze().permute(1, 2, 0))
axs[1, 2].set_title("Reshaped Image (MR=0.4375)")
axs[1, 2].axis('off')

axs[1, 3].imshow(reshaped_60.squeeze().permute(1, 2, 0))
axs[1, 3].set_title("Reshaped Image (MR=0.609375)")
axs[1, 3].axis('off')

plt.show()
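
# Sanity check (assuming the 32x32 CIFAR-10 input and patch_size=1 used above):
# 1024 patches in total, so 1 - MR of them survive: 784 = 28*28, 576 = 24*24, 400 = 20*20.
print(reshaped_23.shape, index_23.shape)  # torch.Size([3, 28, 28]) torch.Size([784])
print(reshaped_43.shape, index_43.shape)  # torch.Size([3, 24, 24]) torch.Size([576])
print(reshaped_60.shape, index_60.shape)  # torch.Size([3, 20, 20]) torch.Size([400])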
