# UOMOP
#
# 수학적 알고리즘의 FLOPs 계산방법 (본문 / post body)
#
# 카테고리 없음 (no category)
#
# 수학적 알고리즘의 FLOPs 계산방법 — How to count the FLOPs of a mathematical algorithm
#
# Happy PinGu, 2024-10-12 18:39
import torch
import numpy as np

def _axis_gradient(window, axis):
    """Return the NumPy finite-difference gradient of ``window`` along ``axis``.

    ``np.gradient`` raises ``ValueError`` when an axis has fewer than 2
    samples (e.g. a 1-pixel patch with ``how_many=0``); such an axis has no
    measurable change, so zeros are returned instead.
    """
    if window.shape[axis] < 2:
        return np.zeros(window.shape)
    return np.gradient(window, axis=axis)


def patch_importance(image, patch_size=2, type='variance', how_many=1):
    """Score every ``patch_size`` x ``patch_size`` patch of an image and count the FLOPs spent.

    The image is tiled into non-overlapping patches. Each patch is widened by
    ``how_many`` pixels on every side (clipped at the image border) and the
    widened window is reduced to one scalar according to ``type``:

    * ``'variance'``        -- standard deviation of the window (``np.std``).
    * ``'mean_brightness'`` -- mean of the window.
    * ``'contrast'``        -- ``max - min`` of the window.
    * ``'edge_density'``    -- sum of gradient magnitudes over the last two axes.
    * ``'color_diversity'`` -- standard deviation of the window.

    Alongside the scores, a hand-written analytical FLOP estimate is
    accumulated per window (per-operation counts, not measured values).

    Args:
        image: Array or tensor whose last two axes are ``(H, W)``. Any leading
            axes (e.g. channels) are reduced together with the spatial window.
            A ``torch.Tensor`` is detached and moved to the CPU first.
        patch_size: Side length of a patch in pixels.
        type: Name of the importance measure (see above). Shadows the builtin
            ``type`` but is kept for backward compatibility.
        how_many: Context margin, in pixels, added on each side of a patch.

    Returns:
        ``(value_map, flops)``: ``value_map`` has shape
        ``(H // patch_size, W // patch_size)``; ``flops`` is the estimated FLOP
        count. Trailing rows/columns that do not fill a whole patch are ignored.

    Raises:
        ValueError: If ``type`` is not one of the supported measures.
    """
    valid_types = ('variance', 'mean_brightness', 'contrast', 'edge_density', 'color_diversity')
    if type not in valid_types:
        raise ValueError(f"Unknown importance type {type!r}; expected one of {valid_types}")

    if isinstance(image, torch.Tensor):
        # Plain .numpy() fails on CUDA tensors and on tensors that require grad.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)

    flops = 0  # running FLOP estimate
    H, W = image.shape[-2:]
    rows, cols = H // patch_size, W // patch_size
    value_map = np.zeros((rows, cols))

    # Iterate over whole patches only, so every index fits inside value_map.
    for r in range(rows):
        i = r * patch_size
        start_i = max(i - how_many, 0)
        end_i = min(i + patch_size + how_many, H)
        for c in range(cols):
            j = c * patch_size
            start_j = max(j - how_many, 0)
            end_j = min(j + patch_size + how_many, W)

            # Slice the last two axes so (C, H, W) inputs work as well as (H, W).
            window = image[..., start_i:end_i, start_j:end_j]
            n = window.size

            if type == 'variance':
                value = np.std(window)
                # n additions + 1 division for the mean that std needs,
                # then 2n for squaring and the deviation terms.
                flops += (n + 1) + 2 * n
            elif type == 'mean_brightness':
                value = np.mean(window)
                flops += n + 1  # n additions + 1 division
            elif type == 'contrast':
                value = window.max() - window.min()
                flops += 2 * n  # one pass each for max and min
            elif type == 'edge_density':
                dy = _axis_gradient(window, -2)
                dx = _axis_gradient(window, -1)
                flops += 2 * n  # finite differences along both spatial axes
                value = np.sum(np.sqrt(dx ** 2 + dy ** 2))
                flops += 3 * n  # squares, sum, square roots
            else:  # 'color_diversity'
                value = np.std(window)
                flops += 2 * n  # standard deviation

            value_map[r, c] = value

    return value_map, flops

def _sort_flops(n):
    """Return the estimated FLOP cost of sorting ``n`` values, modelled as ``n * ln(n)``.

    Returns 0 for ``n < 2``: an empty selection would otherwise evaluate
    ``0 * log(0) = nan`` and turn the running FLOP total into ``nan``.
    """
    return n * np.log(n) if n > 1 else 0


def chessboard_mask(images, patch_size=2, mask_ratio=(0.5,), importance_type='variance', how_many=1):
    """Mask a batch of images with an importance-adjusted chessboard pattern and estimate FLOPs.

    The grid of ``patch_size`` x ``patch_size`` patches starts as a chessboard:
    patches whose ``row + col`` is even are masked (set to 0 in every channel).
    The pattern is then moved toward the requested mask ratio ``MR``:

    * ``MR < 0.5``  -- masked patches with the *highest* importance are
      unmasked until ``int(num_patches * (1 - MR))`` patches are visible.
    * ``MR > 0.5``  -- visible patches with the *lowest* importance are masked
      until only that many remain visible.
    * ``MR == 0.5`` -- the plain chessboard is used.

    Importance comes from ``patch_importance`` evaluated on channel 0 of each
    image.

    Args:
        images: Tensor of shape ``(B, C, H, W)``.
        patch_size: Patch side length in pixels. Trailing rows/columns that do
            not fill a whole patch are never masked.
        mask_ratio: Sequence of mask ratios to run. Each ratio starts from a
            fresh copy of ``images``.
        importance_type: Measure passed to ``patch_importance`` as ``type``.
        how_many: Context margin passed to ``patch_importance``.

    Returns:
        ``(masked_images, total_flops)``: the masked copy produced for the
        *last* ratio in ``mask_ratio`` and the estimated FLOPs summed over all
        ratios.

    Raises:
        ValueError: If ``mask_ratio`` is empty.
    """
    if len(mask_ratio) == 0:
        raise ValueError("mask_ratio must contain at least one ratio")

    B, C, H, W = images.shape
    rows, cols = H // patch_size, W // patch_size
    num_patches = rows * cols
    # Chessboard on the patch grid: True (masked) where (row + col) is even.
    chessboard = (np.add.outer(np.arange(rows), np.arange(cols)) % 2) == 0

    total_flops = 0  # running FLOP estimate across all ratios
    for MR in mask_ratio:
        # NOTE(review): each ratio restarts from the unmasked input, so only the
        # last ratio's images are returned while FLOPs accumulate over all
        # ratios -- confirm this is intended.
        masked_images = images.clone()
        # The epsilon absorbs float error such as (1 - 0.9) * 100 == 9.999999999999998,
        # which int() would otherwise truncate to 9 instead of 10.
        target_unmasked_patches = int(num_patches * (1 - MR) + 1e-9)

        for b in range(B):
            # detach().cpu() lets patch_importance convert the tensor to NumPy
            # even when images live on a GPU or require gradients.
            importance_map, flops = patch_importance(
                images[b, 0].detach().cpu(), patch_size, importance_type, how_many)
            total_flops += flops

            mask = chessboard.copy()
            unmasked_count = int(np.count_nonzero(~mask))
            total_flops += mask.size  # one op per patch for building the mask

            if MR < 0.5:
                # Unmask the most important masked patches until the target is met.
                candidates = np.argwhere(mask)
                importances = importance_map[mask]
                total_flops += _sort_flops(importances.size)
                n_change = max(target_unmasked_patches - unmasked_count, 0)
                chosen = candidates[np.argsort(importances)[::-1][:n_change]]
                mask[chosen[:, 0], chosen[:, 1]] = False
            elif MR > 0.5:
                # Mask the least important visible patches until the target is met.
                candidates = np.argwhere(~mask)
                importances = importance_map[~mask]
                total_flops += _sort_flops(importances.size)
                n_change = max(unmasked_count - target_unmasked_patches, 0)
                chosen = candidates[np.argsort(importances)[:n_change]]
                mask[chosen[:, 0], chosen[:, 1]] = True

            # Expand the patch mask to pixels and zero those pixels in every channel.
            pixel_mask = mask.repeat(patch_size, axis=0).repeat(patch_size, axis=1)
            region = masked_images[b, :, :rows * patch_size, :cols * patch_size]
            region.masked_fill_(torch.from_numpy(pixel_mask).to(region.device), 0)

        # One op per patch per channel for the zero-fill.
        total_flops += B * C * rows * cols

    return masked_images, total_flops

if __name__ == "__main__":
    # Example: estimate the FLOPs for one random (1, 3, 24, 24) batch masked
    # with a 1x1-patch chessboard at mask ratio 0.5. Guarded so importing this
    # module does not run the demo.
    example_images = torch.randn(1, 3, 24, 24)  # example input batch (B, C, H, W)
    masked_images, flops = chessboard_mask(example_images, patch_size=1, mask_ratio=[0.5], importance_type='variance', how_many=1)
    print(f"Total FLOPs: {flops/1000000}M")

 

# Comments