UOMOP

Patch importance 본문

Main

Patch importance

Happy PinGu 2024. 6. 3. 14:09
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

def tensor_to_np(tensor):
    """Convert a 3-D tensor to a NumPy array with its first axis moved last.

    Intended for channel-first image tensors such as ``ToTensor()`` output,
    i.e. (C, H, W) -> (H, W, C), which is the layout ``imshow``/NumPy code
    expects. The tensor is detached from autograd and moved to the CPU first,
    so tensors that require grad or live on a GPU are accepted too; for a
    plain CPU tensor the result is the same as calling ``.numpy()`` directly.
    """
    return tensor.detach().cpu().numpy().transpose((1, 2, 0))


def patch_importance(image, patch_size=2, type='variance', how_many=2,
                     noise_scale=1e-6, rng=None):
    """Score every ``patch_size`` x ``patch_size`` patch of a 2-D image.

    Each patch is scored over an *extended* window that grows the patch by
    ``how_many`` pixels on every side, clipped at the image border, so the
    score also reflects the patch's immediate surroundings. A tiny Gaussian
    jitter is added to every score (NOTE(review): presumably to break ties
    when the map is later ranked -- confirm intent).

    Args:
        image: 2-D array-like of shape (H, W), e.g. a grayscale image.
        patch_size: Side length of a patch in pixels; must be >= 1.
        type: Scoring metric (the name shadows the builtin but is kept for
            backward compatibility). One of:
            ``'variance'``        - standard deviation of the window
                                    (std, not variance);
            ``'mean_brightness'`` - mean of the window;
            ``'contrast'``        - max minus min of the window;
            ``'edge_density'``    - sum of gradient magnitudes in the window;
            ``'color_diversity'`` - standard deviation of the window, i.e.
                                    identical to ``'variance'`` for a
                                    single-channel image.
        how_many: Context margin, in pixels, added on each side of a patch.
        noise_scale: Standard deviation of the jitter added to each score;
            pass 0 for exact, deterministic scores.
        rng: Optional ``numpy.random.Generator`` used for the jitter. When
            None, the global ``np.random`` state is used, as before.

    Returns:
        Float array of shape ``(ceil(H / patch_size), ceil(W / patch_size))``.
        When H and W are multiples of ``patch_size`` this equals
        ``(H // patch_size, W // patch_size)``; otherwise the trailing
        partial patch at the bottom/right edge gets its own cell.

    Raises:
        ValueError: if ``type`` is not a known metric, ``patch_size`` < 1, or
            ``image`` is not 2-D. ``np.gradient`` also raises ValueError for
            ``'edge_density'`` when a window has a side shorter than 2 pixels.
    """
    def _edge_density(window):
        dy, dx = np.gradient(window)
        return np.sum(np.sqrt(dx**2 + dy**2))

    # Dispatch table: validated once instead of re-running an if/elif chain
    # per patch (which also left `value` unbound for an unknown type).
    metrics = {
        'variance': np.std,
        'mean_brightness': np.mean,
        'contrast': lambda window: window.max() - window.min(),
        'edge_density': _edge_density,
        'color_diversity': np.std,
    }
    if type not in metrics:
        raise ValueError(
            f"unknown type {type!r}; expected one of {sorted(metrics)}")
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")

    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"image must be 2-D (H, W), got shape {image.shape}")

    metric = metrics[type]
    H, W = image.shape
    # Ceil-division: the loops below visit ceil(H / patch_size) rows of
    # patches, so a floor-sized map would be indexed out of bounds whenever
    # H or W is not a multiple of patch_size.
    value_map = np.zeros((-(-H // patch_size), -(-W // patch_size)))

    for row, i in enumerate(range(0, H, patch_size)):
        # Extended window bounds, clipped to the image.
        start_i = max(i - how_many, 0)
        end_i = min(i + patch_size + how_many, H)
        for col, j in enumerate(range(0, W, patch_size)):
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            window = image[start_i:end_i, start_j:end_j]
            value = metric(window)

            # One draw per patch in row-major order, same as before.
            noise = np.random.randn() if rng is None else rng.standard_normal()
            value_map[row, col] = value + noise * noise_scale

    return value_map

# Load the first CIFAR-10 training image and convert it to a 2-D grayscale array.
transform = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# num_workers=0: only a single image is fetched, so worker processes add
# nothing. With workers, platforms using the "spawn" start method (Windows,
# macOS) re-import this script in every worker; since the script has no
# `if __name__ == "__main__":` guard, that fails with a RuntimeError.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=0)

# Fetch one image and convert to grayscale
dataiter = iter(trainloader)
images, labels = next(dataiter)
image = tensor_to_np(images[0])  # channel-last (H, W, C) numpy array
# Weighted sum of the R, G, B channels using the ITU-R BT.601 luma weights.
images_gray = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])


# Build one patch-importance map per metric from the same grayscale image.
# The generator is consumed left to right during unpacking, so the calls (and
# their global-RNG jitter draws) happen in exactly this order.
std_map, brightness_map, contrast_map, edge_map, color_diversity_map = (
    patch_importance(images_gray, type=metric_name)
    for metric_name in (
        'variance',
        'mean_brightness',
        'contrast',
        'edge_density',
        'color_diversity',
    )
)

# Plot the grayscale input image next to each of its five importance maps.
# Each entry is (image data, colormap, panel title).
panels = [
    (images_gray, 'gray', 'Input Grayscale Image'),
    (std_map, 'Blues', 'Standard Deviation Map'),
    (brightness_map, 'Blues', 'brightness_map'),
    (contrast_map, 'Blues', 'contrast Map'),
    (edge_map, 'Blues', 'edge'),
    (color_diversity_map, 'Blues', 'color_diversity'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(12, 6))

for ax, (panel_data, panel_cmap, panel_title) in zip(axes, panels):
    ax.imshow(panel_data, cmap=panel_cmap)
    ax.set_title(panel_title)
    ax.axis('off')

plt.show()

'Main' 카테고리의 다른 글

No Masking Symbol Check  (0) 2024.07.05
Object/background focusing  (0) 2024.06.25
Matlab code for PSNR performance comparison  (0) 2024.05.24
DeepJSCC performance ( DIM = 768, 1536, 2304 )  (0) 2024.05.17
Image reconstruction with CBM  (0) 2024.05.03
Comments