UOMOP

CBS 본문

Main

CBS

Happy PinGu 2024. 7. 25. 16:47
def patch_importance(image, patch_size=2, type='variance', how_many=2, noise_scale=0):
    """Score every ``patch_size`` x ``patch_size`` patch of an image.

    Each patch is scored on a window that extends the patch by ``how_many``
    pixels on every side (clipped at the image border), so neighbouring
    context influences the score.

    Args:
        image: ``numpy`` array or ``torch.Tensor`` whose last two axes are
            (height, width). Leading axes, if any, are included in the
            statistics of each window.
        patch_size: Side length of a patch in pixels. Must be positive.
        type: Name of the score (the parameter name shadows the builtin but
            is kept for backward compatibility). One of:
            ``'variance'`` (standard deviation of the window),
            ``'mean_brightness'`` (mean), ``'contrast'`` (max - min),
            ``'edge_density'`` (sum of gradient magnitudes) and
            ``'color_diversity'`` (standard deviation, same as
            ``'variance'``).
        how_many: Number of context pixels added around each patch.
        noise_scale: Standard deviation of Gaussian noise added to every
            score; ``0`` disables the noise.

    Returns:
        ``numpy`` array of shape ``(H // patch_size, W // patch_size)`` with
        one score per whole patch. Trailing rows/columns that do not fill a
        whole patch are ignored.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``type`` is not a
            known score name.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    def _edge_density(region):
        # Gradients along the two spatial axes only, so leading axes
        # (e.g. channels) do not add extra gradient components.
        dy, dx = np.gradient(region, axis=(-2, -1))
        return np.sum(np.sqrt(dx ** 2 + dy ** 2))

    scorers = {
        'variance': np.std,
        'mean_brightness': np.mean,
        'contrast': lambda region: region.max() - region.min(),
        'edge_density': _edge_density,
        'color_diversity': np.std,
    }
    if type not in scorers:
        # Previously an unknown name left the score unbound (or silently
        # reused the previous patch's score).
        raise ValueError(
            f"unknown importance type {type!r}; expected one of {sorted(scorers)}"
        )
    score = scorers[type]

    if isinstance(image, torch.Tensor):
        # detach()/cpu() so tensors that require grad or live on an
        # accelerator can be converted as well.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)

    H, W = image.shape[-2:]
    grid_rows, grid_cols = H // patch_size, W // patch_size
    value_map = np.zeros((grid_rows, grid_cols))

    # Iterate over whole patches only: stepping over raw pixel offsets would
    # index past value_map when H or W is not a multiple of patch_size.
    for row in range(grid_rows):
        top = max(row * patch_size - how_many, 0)
        bottom = min((row + 1) * patch_size + how_many, H)
        for col in range(grid_cols):
            left = max(col * patch_size - how_many, 0)
            right = min((col + 1) * patch_size + how_many, W)

            # Ellipsis keeps the slice on the last two (spatial) axes.
            window = image[..., top:bottom, left:right]

            noise = np.random.randn() * noise_scale
            value_map[row, col] = score(window) + noise

    return value_map

def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1, noise_scale=0):
    """Mask image patches in a checkerboard pattern adjusted by importance.

    Every image starts from a checkerboard in which patches with an even
    ``row + col`` are masked. The pattern is then adjusted towards
    ``mask_ratio`` using the scores from ``patch_importance`` (computed on
    channel 0 only):

    * ``mask_ratio < 0.5``: the most important masked patches are unmasked
      until ``int(num_patches * (1 - mask_ratio))`` patches are visible.
    * ``mask_ratio > 0.5``: the least important visible patches are masked
      until only that many remain visible.

    Args:
        images: Tensor of shape ``(B, C, H, W)``.
        patch_size: Side length of a patch in pixels. Must be positive.
        mask_ratio: Target fraction of patches to mask.
        importance_type: Score name forwarded to ``patch_importance``.
        how_many: Context width forwarded to ``patch_importance``.
        noise_scale: Noise level forwarded to ``patch_importance``.

    Returns:
        A tuple ``(masked_images, unmasked_patches_image, patch_index)``:

        * ``masked_images``: copy of ``images`` with masked patches zeroed.
        * ``unmasked_patches_image``: tensor of shape
          ``(C, patch_size, patch_size * N)`` holding the ``N`` visible
          patches of all images side by side (image after image, each in
          row-major patch order).
        * ``patch_index``: list of ``N`` row-major patch indices
          (``row * (W // patch_size) + col``), in the same order as the
          patches. Indices carry no batch offset.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    B, C, H, W = images.shape
    grid_rows, grid_cols = H // patch_size, W // patch_size
    num_patches = grid_rows * grid_cols
    target_unmasked_patches = int(num_patches * (1 - mask_ratio))

    masked_images = images.clone()
    kept_patches = []
    patch_index = []

    # True marks a masked patch; the base pattern is identical for every image.
    checkerboard = (np.arange(grid_rows)[:, None] + np.arange(grid_cols)[None, :]) % 2 == 0

    for b in range(B):
        # Importance is computed on the first channel only.
        importance_map = patch_importance(
            images[b, 0].detach().cpu(), patch_size, importance_type, how_many, noise_scale
        )

        mask = checkerboard.copy()
        unmasked_count = int(np.count_nonzero(~mask))

        if mask_ratio < 0.5:
            # Reveal the most important masked patches first.
            candidates = np.argwhere(mask)
            ranking = np.argsort(importance_map[mask])[::-1]
            n_flip = max(target_unmasked_patches - unmasked_count, 0)
            chosen = candidates[ranking[:n_flip]]
            mask[chosen[:, 0], chosen[:, 1]] = False
        elif mask_ratio > 0.5:
            # Hide the least important visible patches first.
            candidates = np.argwhere(~mask)
            ranking = np.argsort(importance_map[~mask])
            n_flip = max(unmasked_count - target_unmasked_patches, 0)
            chosen = candidates[ranking[:n_flip]]
            mask[chosen[:, 0], chosen[:, 1]] = True

        for i in range(grid_rows):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(grid_cols):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                if mask[i, j]:
                    masked_images[b, :, rows, cols] = 0
                else:
                    kept_patches.append(images[b, :, rows, cols])
                    # Row-major index: the row stride is the number of patch
                    # columns (W // patch_size), not the number of rows.
                    patch_index.append(grid_cols * i + j)

    if kept_patches:
        # Concatenate once, after the loop, instead of once per image.
        unmasked_patches_image = torch.cat(kept_patches, dim=-1)
    else:
        # Nothing visible (empty batch or everything masked): return an
        # empty strip instead of failing inside torch.cat.
        unmasked_patches_image = torch.zeros(
            (C, patch_size, 0), dtype=images.dtype, device=images.device
        )

    return masked_images, unmasked_patches_image, patch_index



import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Load the CIFAR-10 training set (downloaded to ./data on first use).
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# num_workers=0: this loader is iterated at module level without an
# `if __name__ == "__main__":` guard, so worker processes would re-import
# this script on spawn-based platforms (Windows, macOS) and fail. A single
# one-image batch does not need worker processes anyway.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0)

# Fetch one batch (a single image, since batch_size=1).
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Masking configuration.
patch_size = 8
mask_ratio = 0.75
importance_type = 'variance'
how_many = 1
noise_scale = 0

masked_images, unmasked_patches_image, patch_index = chessboard_mask(images, patch_size, mask_ratio, importance_type, how_many, noise_scale)

print(masked_images.shape)
print(unmasked_patches_image.shape)
print(len(patch_index))

def reconstruct_masked_image(unmasked_patches, patch_index, image_shape, patch_size):
    """Place visible patches back at their grid positions in a zero image.

    Args:
        unmasked_patches: Tensor of shape ``(C, patch_size, patch_size * N)``
            holding ``N`` patches side by side along the last axis.
        patch_index: Sequence of ``N`` row-major patch indices
            (``row * (W // patch_size) + col``), one per patch, in the same
            order as the patches.
        image_shape: ``(B, C, H, W)`` shape of the images to rebuild.
        patch_size: Side length of a patch in pixels.

    Returns:
        Tensor of shape ``(B, C, H, W)`` with the dtype and device of
        ``unmasked_patches``; positions without a patch stay zero.

    Note:
        When ``N`` is a multiple of ``B`` the patches are assumed to be
        grouped image after image with the same count per image, and are
        distributed accordingly. Otherwise all patches are written to the
        first image, as in the original single-image behaviour.
    """
    B, C, H, W = image_shape
    # Allocate directly with the input's dtype/device so values are not
    # silently cast to float32.
    reconstructed_images = torch.zeros(
        (B, C, H, W), dtype=unmasked_patches.dtype, device=unmasked_patches.device
    )

    grid_cols = W // patch_size
    total = len(patch_index)
    # Patches per image; falls back to "everything belongs to image 0" when
    # the patches cannot be split evenly across the batch.
    per_image = total // B if B > 0 and total % B == 0 else total

    for k, linear_idx in enumerate(patch_index):
        b = k // per_image
        i, j = divmod(linear_idx, grid_cols)
        # The k-th patch occupies the k-th patch_size-wide column band.
        tile = unmasked_patches[:, :, patch_size * k:patch_size * (k + 1)]
        reconstructed_images[
            b, :, i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size
        ] = tile

    return reconstructed_images

# Rebuild a full-size image from the visible patches and their grid indices.
image_shape = images.shape
reconstructed_image = reconstruct_masked_image(
    unmasked_patches=unmasked_patches_image,
    patch_index=patch_index,
    image_shape=image_shape,
    patch_size=patch_size,
)

def visualize_masked_image(original, masked, unmasked_patches_image, reconstructed):
    """Show the original, masked, visible-patch strip and reconstruction.

    Args:
        original: Channel-first ``(C, H, W)`` array of the original image.
        masked: Channel-first ``(C, H, W)`` array of the masked image.
        unmasked_patches_image: Channel-first tensor holding the visible
            patches side by side.
        reconstructed: Batched tensor ``(B, C, H, W)``; only the first image
            is displayed.
    """
    _, axes = plt.subplots(1, 4, figsize=(15, 5))

    # imshow expects channel-last data, hence the (1, 2, 0) transposes.
    axes[0].imshow(np.transpose(original, (1, 2, 0)))
    axes[0].set_title('Original Image')

    axes[1].imshow(np.transpose(masked, (1, 2, 0)))
    axes[1].set_title('Masked Image')

    # This panel shows the strip of visible patches, not a reconstruction;
    # it was previously mislabelled 'Reconstructed Image'.
    patch_strip = unmasked_patches_image.detach().cpu().numpy()
    axes[2].imshow(np.transpose(patch_strip, (1, 2, 0)))
    axes[2].set_title('Unmasked Patches')

    first_reconstruction = reconstructed[0].detach().cpu().numpy()
    axes[3].imshow(np.transpose(first_reconstruction, (1, 2, 0)))
    axes[3].set_title('Reconstructed Image')

    plt.show()

# Visualize the first image of the batch.
visualize_masked_image(
    original=images[0].numpy(),
    masked=masked_images[0].numpy(),
    unmasked_patches_image=unmasked_patches_image,
    reconstructed=reconstructed_image,
)

'Main' 카테고리의 다른 글

Good  (1) 2024.07.26
Position Estimator  (0) 2024.07.10
No Encoder Symbol Check  (1) 2024.07.05
No Masking Symbol Check  (0) 2024.07.05
Object/background focusing  (0) 2024.06.25
Comments