UOMOP

Patch complexity calculated region extending

Main

Patch complexity calculated region extending

Happy PinGu 2024. 5. 3. 14:58
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2

# NOTE(review): mask_patches_chessboard / new_mask_patches_chessboard and
# patch_std are defined further down in this file; run as a flat script this
# section would raise NameError — presumably the definitions preceded this
# snippet in the original post.

# Load the CIFAR-10 test set.
transform = transforms.Compose([
    transforms.ToTensor()  # convert PIL images to float CHW tensors in [0, 1]
])
# Only a single image is needed for this visual check.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

# Grab one (images, labels) batch from the loader.
dataiter = iter(testloader)
images, labels = next(dataiter)

# Apply both masking variants: 25% target mask ratio, complexity-driven.
masked_images = mask_patches_chessboard(images, patch_size=2, mask_ratio=0.25, complexity_based=True)
new_masked_images = new_mask_patches_chessboard(images, patch_size=2, mask_ratio=0.25, complexity_based=True)

# 원본 이미지와 마스킹된 이미지 시각화
def imshow(img):
    """Display a CHW image tensor with matplotlib.

    Pixel values are shown as-is (assumed already in [0, 1]); no
    unnormalization is applied.
    """
    hwc = img.numpy().transpose((1, 2, 0))  # CHW -> HWC for plt.imshow
    plt.imshow(hwc)
    plt.show()

# Visualize the original batch next to both masked variants.
# (Fix: the final line previously ended with an invisible zero-width space,
# U+200B, which is a SyntaxError in Python source.)
print("Original Image:")
imshow(torchvision.utils.make_grid(images))

# Chessboard-masked output of mask_patches_chessboard.
print("Masked Image:")
imshow(torchvision.utils.make_grid(masked_images))

# Output of the expanded-window variant, new_mask_patches_chessboard.
print("New Masked Image:")
imshow(torchvision.utils.make_grid(new_masked_images))
def patch_std(image, patch_size=2):
    """Per-patch standard-deviation map of a 2-D image.

    Tiles the image into non-overlapping ``patch_size`` x ``patch_size``
    patches and returns the std of each patch.

    Parameters
    ----------
    image : np.ndarray
        2-D array of shape (H, W). Trailing rows/columns that do not fill a
        complete patch are ignored. (The original loop walked past them and
        raised IndexError when H or W was not divisible by patch_size.)
    patch_size : int, optional
        Side length of the square patches (default 2).

    Returns
    -------
    np.ndarray
        Float array of shape (H // patch_size, W // patch_size).
    """
    H, W = image.shape
    rows, cols = H // patch_size, W // patch_size
    # Crop to whole patches, then fold each patch into its own axis pair so
    # one vectorized std over axes (1, 3) replaces the double Python loop.
    tiles = np.asarray(image, dtype=float)[:rows * patch_size, :cols * patch_size]
    tiles = tiles.reshape(rows, patch_size, cols, patch_size)
    return tiles.std(axis=(1, 3))


def mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    """Zero out patches of a batch of images in a chessboard-based pattern.

    images: (B, C, H, W) float tensor. Values are scaled by 255 and passed
    to cv2.cvtColor(..., COLOR_RGB2GRAY), so RGB data in [0, 1] is assumed
    — TODO confirm with callers. H and W are presumably multiples of
    2 * patch_size (the pattern/tiling math truncates otherwise) — verify.

    mask_ratio: target fraction of masked patches. Exactly 0.5 takes a fast
    path (pure chessboard). Other ratios only take effect when
    complexity_based is True; with complexity_based=False the mask stays the
    plain 50% chessboard regardless of the requested ratio (NOTE(review):
    possibly unintended).

    Returns a new tensor of the same shape with masked patches set to 0.
    """
    if mask_ratio != 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        for b in range(B):
            # To uint8 HWC for OpenCV; assumes input pixel values in [0, 1].
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # Per-patch std of the grayscale image as a complexity score.
            complexity_map = patch_std(gray, patch_size)

            # Start from a 50% chessboard mask over the patch grid
            # (True = patch will be zeroed).
            complexity_height, complexity_width = complexity_map.shape
            mask = np.zeros((complexity_height, complexity_width), dtype=bool)
            mask[::2, 1::2] = 1
            mask[1::2, ::2] = 1

            if complexity_based:
                if mask_ratio > 0.5:
                    # Additionally mask the lowest-complexity patches among
                    # the currently unmasked ones (std <= quantile threshold)
                    # until approximately mask_ratio of patches are masked.
                    additional_masking_ratio = (mask_ratio - 0.5) / 0.5
                    complexity_threshold = np.quantile(complexity_map[~mask], 1 - additional_masking_ratio)
                    additional_mask = complexity_map <= complexity_threshold
                    mask[~mask] = additional_mask[~mask]


                else:
                    # Unmask the highest-complexity patches among the
                    # currently masked ones (std >= quantile threshold)
                    # until approximately mask_ratio of patches remain masked.
                    unmasking_ratio = (0.5 - mask_ratio) / 0.5
                    complexity_threshold = np.quantile(complexity_map[mask], unmasking_ratio)
                    unmask = complexity_map >= complexity_threshold
                    mask[mask] = ~unmask[mask]

            # Zero every masked patch in the uint8 image (all channels).
            for i in range(complexity_height):
                for j in range(complexity_width):
                    if mask[i, j]:
                        image[i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

            # Back to float CHW tensor in [0, 1].
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    elif mask_ratio == 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        # Patch-grid chessboard pattern (1 = keep, 0 = mask), built once for
        # the whole batch.
        pattern = np.tile(np.array([[1, 0] * (W // (2 * patch_size)), [0, 1] * (W // (2 * patch_size))]),
                          (H // (2 * patch_size), 1))

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)

            # Expand the patch-grid pattern to pixel resolution, then zero
            # all channels of the masked pixels.
            mask = np.repeat(np.repeat(pattern, patch_size, axis=0), patch_size, axis=1)
            image[mask == 0] = 0  # Apply chessboard pattern masking

            # Back to float CHW tensor in [0, 1].
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    return masked_images
def new_patch_std(image, patch_size=2):
    """Std-deviation map computed over enlarged patch windows.

    Like patch_std, but each patch's std is taken over a window that reaches
    patch_size // 2 pixels before the window anchor and patch_size pixels
    past it (clipped at the image borders), so neighbouring pixels influence
    each patch's complexity score. Returns an array of shape
    (H // patch_size, W // patch_size); entries the anchored loop never
    reaches (possible for odd patch_size) stay zero.
    """
    H, W = image.shape
    std_map = np.zeros((H // patch_size, W // patch_size))

    # First anchor sits at the centre offset of a doubled (2x) patch.
    offset = patch_size - patch_size // 2
    half = patch_size // 2

    for row in range(offset, H - offset + 1, patch_size):
        top = max(0, row - half)
        bottom = min(H, row + half + patch_size)
        for col in range(offset, W - offset + 1, patch_size):
            left = max(0, col - half)
            right = min(W, col + half + patch_size)
            # Std over the expanded window around this anchor.
            window = image[top:bottom, left:right]
            std_map[(row - offset) // patch_size, (col - offset) // patch_size] = window.std()

    return std_map


def new_mask_patches_chessboard(images, patch_size=2, mask_ratio=0.5, complexity_based=False):
    """Chessboard-based patch masking using expanded-window complexity.

    Identical to mask_patches_chessboard except that patch complexity comes
    from new_patch_std, whose std windows extend beyond each patch so that
    neighbouring pixels contribute to the score.

    images: (B, C, H, W) float tensor; RGB in [0, 1] is assumed (values are
    scaled by 255 and fed to cv2.cvtColor with COLOR_RGB2GRAY) — TODO
    confirm. mask_ratio other than 0.5 only takes effect when
    complexity_based is True; otherwise the 50% chessboard is used
    regardless of the requested ratio (NOTE(review): possibly unintended).

    Returns a new tensor of the same shape with masked patches set to 0.
    """
    if mask_ratio != 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        for b in range(B):
            # To uint8 HWC for OpenCV; assumes input pixel values in [0, 1].
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

            # Expanded-window per-patch std as the complexity score.
            complexity_map = new_patch_std(gray, patch_size)

            # Start from a 50% chessboard mask over the patch grid
            # (True = patch will be zeroed).
            complexity_height, complexity_width = complexity_map.shape
            mask = np.zeros((complexity_height, complexity_width), dtype=bool)
            mask[::2, 1::2] = 1
            mask[1::2, ::2] = 1

            if complexity_based:
                if mask_ratio > 0.5:
                    # Additionally mask the lowest-complexity patches among
                    # the currently unmasked ones (std <= quantile threshold)
                    # until approximately mask_ratio of patches are masked.
                    additional_masking_ratio = (mask_ratio - 0.5) / 0.5
                    complexity_threshold = np.quantile(complexity_map[~mask], 1 - additional_masking_ratio)
                    additional_mask = complexity_map <= complexity_threshold
                    mask[~mask] = additional_mask[~mask]


                else:
                    # Unmask the highest-complexity patches among the
                    # currently masked ones (std >= quantile threshold)
                    # until approximately mask_ratio of patches remain masked.
                    unmasking_ratio = (0.5 - mask_ratio) / 0.5
                    complexity_threshold = np.quantile(complexity_map[mask], unmasking_ratio)
                    unmask = complexity_map >= complexity_threshold
                    mask[mask] = ~unmask[mask]

            # Zero every masked patch in the uint8 image (all channels).
            for i in range(complexity_height):
                for j in range(complexity_width):
                    if mask[i, j]:
                        image[i * patch_size:(i + 1) * patch_size, j * patch_size:(j + 1) * patch_size] = 0

            # Back to float CHW tensor in [0, 1].
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    elif mask_ratio == 0.5:
        B, C, H, W = images.shape
        masked_images = images.clone()

        # Patch-grid chessboard pattern (1 = keep, 0 = mask), built once for
        # the whole batch.
        pattern = np.tile(np.array([[1, 0] * (W // (2 * patch_size)), [0, 1] * (W // (2 * patch_size))]),
                          (H // (2 * patch_size), 1))

        for b in range(B):
            image = images[b].permute(1, 2, 0).cpu().numpy() * 255
            image = image.astype(np.uint8)

            # Expand the patch-grid pattern to pixel resolution, then zero
            # all channels of the masked pixels.
            mask = np.repeat(np.repeat(pattern, patch_size, axis=0), patch_size, axis=1)
            image[mask == 0] = 0  # Apply chessboard pattern masking

            # Back to float CHW tensor in [0, 1].
            image = image.astype(np.float32) / 255.0
            masked_images[b] = torch.from_numpy(image).permute(2, 0, 1)

    return masked_images

 

Performance verification in progress (성능 확인중)

'Main' 카테고리의 다른 글

DeepJSCC performance ( DIM = 768, 1536, 2304 )  (0) 2024.05.17
Image reconstruction with CBM  (0) 2024.05.03
ChessBoard Masking with Colored Random Noise  (0) 2024.05.02
CBS (odd, odd) masking  (0) 2024.04.18
Masking strategy comparison  (0) 2024.04.11
Comments