UOMOP

DE : Selection (33%, 60%, 75%) 본문

DE/Code

DE : Selection (33%, 60%, 75%)

Happy PinGu 2024. 8. 15. 20:28
import torchvision.transforms as transforms
import math
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import time

import os
from tqdm import tqdm
import numpy as np
import math
import torch
import torchvision
from fractions import Fraction
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import matplotlib.pyplot as plt
import torchvision.transforms as tr
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import time
import os
from skimage.metrics import structural_similarity as ssim
import torch

from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
def patch_importance(image, patch_size=2, type='variance', how_many=2):
    """Score every ``patch_size`` x ``patch_size`` patch of an image.

    Each patch is scored over an extended window reaching ``how_many`` pixels
    beyond the patch on every side (clipped at the image border).

    Args:
        image: 2-D array or tensor (H, W). Leading axes (e.g. channels) are
            also accepted; only the last two axes are windowed and the
            statistic is taken over the whole window.
        patch_size: Side length of a square patch, in pixels.
        type: Scoring statistic, one of
            ``'variance'`` (standard deviation of the window),
            ``'mean_brightness'`` (mean of the window),
            ``'contrast'`` (max - min of the window),
            ``'edge_density'`` (sum of gradient magnitudes over the last two axes),
            ``'color_diversity'`` (standard deviation; identical to ``'variance'``).
        how_many: Margin, in pixels, added around each patch.

    Returns:
        ``np.ndarray`` of shape ``(H // patch_size, W // patch_size)``.
        Trailing rows/columns that do not fill a whole patch are ignored.

    Raises:
        ValueError: If ``type`` is not one of the names above.
    """
    if isinstance(image, torch.Tensor):
        # Plain .numpy() fails on GPU tensors and on tensors that require grad.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)

    valid_types = ('variance', 'mean_brightness', 'contrast', 'edge_density', 'color_diversity')
    if type not in valid_types:
        raise ValueError(f"unknown importance type {type!r}; expected one of {valid_types}")

    H, W = image.shape[-2:]
    grid_h, grid_w = H // patch_size, W // patch_size
    value_map = np.zeros((grid_h, grid_w))

    # Walk whole patches only, so every map index fits inside value_map.
    for gi in range(grid_h):
        i = gi * patch_size
        start_i = max(i - how_many, 0)
        end_i = min(i + patch_size + how_many, H)
        for gj in range(grid_w):
            j = gj * patch_size
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            # `...` keeps any leading axes intact; only the spatial axes are windowed.
            window = image[..., start_i:end_i, start_j:end_j]

            if type in ('variance', 'color_diversity'):
                value = np.std(window)
            elif type == 'mean_brightness':
                value = np.mean(window)
            elif type == 'contrast':
                value = window.max() - window.min()
            else:  # 'edge_density'
                dy, dx = np.gradient(window, axis=(-2, -1))
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))

            value_map[gi, gj] = value

    return value_map
def _importance_adjusted_mask(grid_shape, importance_map, mask_ratio, target_unmasked):
    """Build one image's boolean patch mask (True = masked) for ``mask_ratio``.

    Starts from a chessboard (patches with even ``i + j`` masked, i.e. half of
    them), then flips patches in importance order until ``target_unmasked``
    patches remain visible. ``importance_map`` must have shape ``grid_shape``.
    """
    rows, cols = np.indices(grid_shape)
    mask = (rows + cols) % 2 == 0
    unmasked_count = np.sum(~mask)

    if mask_ratio < 0.5:
        # Need more visible patches: reveal the most important masked ones first.
        candidates = np.argwhere(mask)
        order = np.argsort(importance_map[mask])[::-1]
        for idx in candidates[order]:
            if unmasked_count >= target_unmasked:
                break
            mask[tuple(idx)] = False
            unmasked_count += 1
    elif mask_ratio > 0.5:
        # Need fewer visible patches: hide the least important visible ones first.
        candidates = np.argwhere(~mask)
        order = np.argsort(importance_map[~mask])
        for idx in candidates[order]:
            if unmasked_count <= target_unmasked:
                break
            mask[tuple(idx)] = True
            unmasked_count -= 1

    return mask


def chessboard_mask(images, patch_size=2, mask_ratio=0.5, importance_type='variance', how_many=1):
    """Mask a batch of images with an importance-adjusted chessboard pattern.

    The images are split into ``patch_size`` x ``patch_size`` patches. For each
    requested ratio and each image, a chessboard mask (half the patches) is
    adjusted with ``patch_importance`` computed on channel 0 until
    ``int(num_patches * (1 - ratio))`` patches stay visible.

    Args:
        images: Tensor of shape (B, C, H, W).
        patch_size: Side length of a square patch, in pixels.
        mask_ratio: A mask ratio or a sequence of mask ratios. Results are only
            returned for 0.33984, 0.609375 and 0.75, so all three must be given.
        importance_type: Statistic forwarded to ``patch_importance``.
        how_many: Window margin forwarded to ``patch_importance``.

    Returns:
        ``(masked_33, masked_60, masked_75, reshaped_33, reshaped_60,
        reshaped_75, index_33, index_60, index_75)`` where

        * ``masked_*``: copy of ``images`` with masked patches zeroed, (B, C, H, W);
        * ``reshaped_*``: the kept patches of all images laid side by side in
          row-major order, cut into 26/20/16-pixel-wide pieces and stacked
          vertically. For a 32x32 input with ``patch_size=2`` this gives
          (C, 26*B, 26), (C, 20*B, 20) and (C, 16*B, 16);
        * ``index_*``: 1-D tensor of the row-major patch indices of every kept
          patch, all images concatenated in batch order.

    Raises:
        ValueError: If any of the three supported ratios is missing.
    """
    # (mask ratio, split length in pixels). With a 32x32 input and patch_size=2
    # these ratios keep 169/100/64 = 13^2/10^2/8^2 patches, so 26/20/16 px
    # (13/10/8 patches) tile each image's kept patches into a square. Other
    # input sizes give non-square or failing reshapes.
    ratio_slots = ((0.33984, 26), (0.609375, 20), (0.75, 16))

    # Accept a single ratio as well as a sequence of ratios.
    ratios = [mask_ratio] if np.ndim(mask_ratio) == 0 else list(mask_ratio)

    B, C, H, W = images.shape
    grid_h, grid_w = H // patch_size, W // patch_size
    num_patches = grid_h * grid_w
    results = [None] * len(ratio_slots)  # (masked, reshaped, index) per slot

    for ratio in ratios:
        target_unmasked = int(num_patches * (1 - ratio))
        masked_images = images.clone()
        kept_per_image = []
        patch_index = []

        for b in range(B):
            importance_map = patch_importance(images[b, 0], patch_size, importance_type, how_many)
            mask = _importance_adjusted_mask((grid_h, grid_w), importance_map, ratio, target_unmasked)

            kept = []
            for i in range(grid_h):
                row_span = slice(i * patch_size, (i + 1) * patch_size)
                for j in range(grid_w):
                    col_span = slice(j * patch_size, (j + 1) * patch_size)
                    if mask[i, j]:
                        masked_images[b, :, row_span, col_span] = 0
                    else:
                        kept.append(images[b, :, row_span, col_span])
                        # Row-major index: the row stride is the number of patch columns.
                        patch_index.append(i * grid_w + j)
            # This image's kept patches side by side: (C, patch_size, n_kept * patch_size).
            kept_per_image.append(torch.cat(kept, dim=-1))

        slot = next(
            (k for k, (r, _) in enumerate(ratio_slots) if math.isclose(ratio, r, abs_tol=1e-9)),
            None,
        )
        if slot is None:
            continue  # computed, but not one of the returned outputs

        # Cut the strip of kept patches into split_len-wide pieces and stack them vertically.
        split_len = ratio_slots[slot][1]
        strip = torch.cat(kept_per_image, dim=-1)
        reshaped = torch.cat(torch.split(strip, split_len, dim=2), dim=1)
        results[slot] = (masked_images, reshaped, torch.tensor(patch_index))

    missing = [r for (r, _), res in zip(ratio_slots, results) if res is None]
    if missing:
        raise ValueError(
            f"chessboard_mask returns results for mask ratios {[r for r, _ in ratio_slots]}; "
            f"missing from mask_ratio: {missing}"
        )

    (masked_33, reshaped_33, index_33), (masked_60, reshaped_60, index_60), (masked_75, reshaped_75, index_75) = results
    return masked_33, masked_60, masked_75, reshaped_33, reshaped_60, reshaped_75, index_33, index_60, index_75
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# Load the CIFAR-10 test split (train=False), converting each image with ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
cifar10 = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Pick one test image and give it a leading batch dimension of size 1.
image, _ = cifar10[124]
image = image[None]

# Run the importance-guided chessboard masking at the three supported ratios.
(
    masked_33, masked_60, masked_75,
    reshaped_33, reshaped_60, reshaped_75,
    index_33, index_60, index_75,
) = chessboard_mask(image, patch_size=2, mask_ratio=[0.33984, 0.609375, 0.75])

def plot_image(tensor, title):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        tensor: Image tensor of shape (C, H, W), or (1, C, H, W). A leading
            dimension of size 1 is squeezed away before plotting.
        title: Title for the axes.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be converted;
    # permute to channels-last (H, W, C) for imshow.
    img = tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    plt.imshow(img)
    plt.title(title)
    plt.axis('off')

# Top row: masked images; bottom row: reshaped (kept-patch) tensors.
_panels = [
    (masked_33, 'Masked 33%'),
    (masked_60, 'Masked 60%'),
    (masked_75, 'Masked 75%'),
    (reshaped_33, 'Reshaped 33%'),
    (reshaped_60, 'Reshaped 60%'),
    (reshaped_75, 'Reshaped 75%'),
]

plt.figure(figsize=(12, 8))
for _position, (_tensor, _title) in enumerate(_panels, start=1):
    plt.subplot(2, 3, _position)
    plot_image(_tensor, _title)
plt.show()

# Shapes of the masked and reshaped tensors, then the number of kept patch indices.
for _tensor in (masked_33, masked_60, masked_75, reshaped_33, reshaped_60, reshaped_75):
    print(_tensor.shape)

for _index in (index_33, index_60, index_75):
    print(len(_index))

Output:
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([3, 26, 26])
torch.Size([3, 20, 20])
torch.Size([3, 16, 16])
169
100
64

 

'DE > Code' 카테고리의 다른 글

Level decision model's MSE training (x+x_f, x, x_f)  (0) 2024.08.27
Cifar10 Fourier  (0) 2024.08.27
Masked, Reshaped, index Gen  (0) 2024.08.10
DeepJSCC BenchMark  (0) 2024.08.10
Proposed Network Architecture  (0) 2024.08.10
Comments