UOMOP

Transmission code 20240911 본문

DE/Code

Transmission code 20240911

Happy PinGu 2024. 9. 11. 13:24
import torch
import torch.nn as nn
from thop import profile

class Encoder(nn.Module):
    """Multi-resolution convolutional encoder producing a fixed-size latent vector.

    The last spatial dimension of the input selects an encoder level
    (32 -> 1, 20 -> 2, 16 -> 3, 12 -> 4). Level ``k`` applies its own stem
    convolution ``in<k>`` followed by the shared tail ``out<k>, ..., out4``,
    each followed by the shared PReLU. For the four supported sizes every path
    ends in a 32-channel 4x4 map. ``AdaptiveAvgPool2d`` resizes that map to
    8x8. It is then flattened (32 * 8 * 8 = 2048 features) and projected to
    ``latent_dim``.

    Args:
        latent_dim: Length of the latent vector returned by ``forward``.
    """

    # Supported input side length -> encoder level.
    _SIZE_TO_LEVEL = {32: 1, 20: 2, 16: 3, 12: 4}

    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim

        # Per-level stem convolutions: stride 2 is applied in these early
        # layers and the kernel size is reduced to 3.
        self.in1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=0)
        self.in2 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=2)
        self.in3 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=0)
        self.in4 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=0)

        # Shared tail: level k runs out<k>, out<k+1>, ..., out4.
        self.out1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=2)
        self.out2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=2)
        self.out3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=2)
        self.out4 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=2)

        # One PReLU module reused after every convolution.
        self.prelu = nn.PReLU()

        self.pool = nn.AdaptiveAvgPool2d((8, 8))

        self.flatten = nn.Flatten()
        # 32 channels * 8 * 8 pooled positions = 2048 input features.
        self.linear = nn.Linear(32 * 8 * 8, self.latent_dim)

    def forward(self, x):
        """Encode ``x`` and report which encoder level was used.

        Args:
            x: Image batch of shape ``(N, 3, S, S)`` with ``S`` in
                {32, 20, 16, 12}. Only ``x.shape[-1]`` is inspected, so inputs
                are assumed to be square.

        Returns:
            ``(encoded, encoder_level)``. ``encoded`` has shape
            ``(N, latent_dim)`` and ``encoder_level`` is an int in 1..4.

        Raises:
            ValueError: If ``x.shape[-1]`` is not a supported size.
        """
        size = x.shape[-1]
        try:
            encoder_level = self._SIZE_TO_LEVEL[size]
        except KeyError:
            raise ValueError(
                f"Unsupported input size {size}; expected one of "
                f"{sorted(self._SIZE_TO_LEVEL)}"
            ) from None

        stems = (self.in1, self.in2, self.in3, self.in4)
        tail = (self.out1, self.out2, self.out3, self.out4)

        x = self.prelu(stems[encoder_level - 1](x))
        for layer in tail[encoder_level - 1:]:
            x = self.prelu(layer(x))

        x = self.pool(x)
        x = self.flatten(x)
        encoded = self.linear(x)

        return encoded, encoder_level

# Every input resolution the Encoder supports, as (batch, channels, height, width).
input_sizes = [
    (1, 3, 32, 32),
    (1, 3, 20, 20),
    (1, 3, 16, 16),
    (1, 3, 12, 12),
]

# One Encoder instance, reused for every resolution below
# (latent_dim = 512 is an example value).
latent_dim = 512
encoder = Encoder(latent_dim)

# Profile the encoder once per supported resolution with thop.
# NOTE(review): thop's documentation describes the first value returned by
# `profile` as MACs rather than FLOPs; confirm the "FLOPs" label is intended.
for input_size in input_sizes:
    print(f"Input size: {input_size}")
    dummy_input = torch.randn(input_size)

    flops, params = profile(encoder, inputs=(dummy_input,))

    report = (
        f"FLOPs: {flops / 1_000_000}M",
        f"Params: {params}\n",
        "=" * 80 + "\n",
    )
    print(*report, sep="\n")
class Decoder(nn.Module):
    """Decoder that mirrors ``Encoder``: latent vector -> sigmoid image.

    The latent is projected to 2048 features and reshaped to ``(32, 8, 8)``.
    It is then upsampled by the transposed convolutions ``in4, in3, ...``.
    The ``encoder_level`` reported by ``Encoder`` selects how many of them
    run and which 3-channel head produces the image:

        level 1: in4, in3, in2, in1 -> out1  (32x32)
        level 2: in4, in3, in2      -> out2  (20x20)
        level 3: in4, in3           -> out3  (16x16)
        level 4: in4                -> out4  (12x12)

    Args:
        latent_dim: Length of the latent vector accepted by ``forward``.
    """

    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim

        # 2048 = 32 channels * 8 * 8, matching the Unflatten target below.
        self.linear = nn.Linear(self.latent_dim, 32 * 8 * 8)
        # One PReLU module reused after the linear layer and every hidden layer.
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))

        # Hidden transposed convolutions, always applied in the order in4 -> in1.
        self.in4 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=0)
        self.in3 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.in2 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.in1 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)

        # Per-level output heads mapping 32 channels to a 3-channel image.
        self.out4 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=2, output_padding=0)
        self.out3 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=0, output_padding=0)
        self.out2 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=0, output_padding=0)
        self.out1 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=3, output_padding=1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        """Decode latent ``x`` at the resolution implied by ``encoder_level``.

        Args:
            x: Latent batch of shape ``(N, latent_dim)``.
            encoder_level: Level returned by ``Encoder.forward`` (1, 2, 3 or 4).

        Returns:
            Sigmoid-activated tensor of shape ``(N, 3, S, S)``, with ``S``
            given by the table in the class docstring.

        Raises:
            ValueError: If ``encoder_level`` is not 1, 2, 3 or 4.
        """
        x = self.unflatten(self.prelu(self.linear(x)))

        if encoder_level == 1:
            x = self.prelu(self.in4(x))
            x = self.prelu(self.in3(x))
            x = self.prelu(self.in2(x))
            x = self.prelu(self.in1(x))
            x = self.out1(x)

        elif encoder_level == 2:
            x = self.prelu(self.in4(x))
            x = self.prelu(self.in3(x))
            x = self.prelu(self.in2(x))
            x = self.out2(x)

        elif encoder_level == 3:
            x = self.prelu(self.in4(x))
            x = self.prelu(self.in3(x))
            x = self.out3(x)

        elif encoder_level == 4:
            x = self.prelu(self.in4(x))
            x = self.out4(x)

        else:
            # Reject unknown levels instead of returning the un-decoded
            # 32x8x8 feature map passed through a sigmoid.
            raise ValueError(
                f"Unsupported encoder_level {encoder_level!r}; expected 1, 2, 3 or 4"
            )

        decoded = self.sigmoid(x)

        return decoded
import torch
import torch.nn as nn

def generate_random_image(width, height=None):
    """Return a random image batch of shape ``(1, 3, height, width)``.

    Values are drawn with ``torch.randn``.

    Args:
        width: Image width in pixels.
        height: Image height in pixels; defaults to ``width`` (square image).
    """
    if height is None:
        height = width
    return torch.randn(1, 3, height, width)


# Smoke test: encode a random image and decode it at the level the encoder chose.
# This guard was previously indented inside generate_random_image after its
# return, with its body at the same indent, which made the file unparseable.
if __name__ == '__main__':
    latent_dim = 512
    img_size = 32

    image = generate_random_image(img_size)
    print(f"input img size : {image.shape}")

    encoder = Encoder(latent_dim)

    decoder = Decoder(latent_dim)

    encoded, encoder_level = encoder(image)
    print(f"encoded.shape = {encoded.shape}")
    print(f"Encoder's level = {encoder_level}")

    decoded = decoder(encoded, encoder_level)
    print(f"decoded.shape = {decoded.shape}")

'DE > Code' 카테고리의 다른 글

FLOPs (Proposed encoder)  (0) 2024.08.28
Level decision model's MSE training (x+x_f, x, x_f)  (0) 2024.08.27
Cifar10 Fourier  (0) 2024.08.27
DE : Selection (33%, 60%, 75%)  (0) 2024.08.15
Masked, Reshaped, index Gen  (0) 2024.08.10
Comments