UOMOP

Adaptive encoder 본문

DE/Code

Adaptive encoder

Happy PinGu 2024. 8. 6. 13:55
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder whose depth adapts to the input resolution.

    Every 5x5 convolution with ``padding=1`` and ``stride=1`` shrinks each
    spatial dimension by 2.  A smaller input therefore enters the network
    through its own stem (``in1`` .. ``in6``) and skips the first shared
    convolutions (``out1`` .. ``out6``), so that a square input of any
    supported size is 18x18 when it reaches the strided ``essen``
    convolution.  ``essen`` yields 32 channels of 8x8, i.e. the 2048
    features consumed by ``linear``.

    Args:
        latent_dim: Number of features in the encoded output vector.
    """

    # Supported sizes of the last input dimension, ordered by encoder level:
    # index 0 (size 32) uses in1 + out1..out6, index 5 (size 22) uses
    # in6 + out6 only.
    _SUPPORTED_SIZES = (32, 30, 28, 26, 24, 22)

    def __init__(self, latent_dim: int) -> None:
        super().__init__()
        self.latent_dim = latent_dim

        # Sub-modules keep their original attribute names and creation order
        # so that state_dict keys of existing checkpoints still match.
        for idx in range(1, 7):
            setattr(
                self,
                f"in{idx}",
                nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1),
            )
        for idx in range(1, 7):
            setattr(
                self,
                f"out{idx}",
                nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1),
            )

        # A single PReLU instance (one learnable slope) is shared by all layers.
        self.prelu = nn.PReLU()

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, self.latent_dim)

        # Final strided convolution applied at every encoder level.
        self.essen = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into ``latent_dim`` features.

        Args:
            x: Input batch with 3 channels; the size of its last dimension
                selects the encoder level and must be one of
                32, 30, 28, 26, 24 or 22.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.

        Raises:
            ValueError: If the last dimension of ``x`` is not a supported
                size.  Without this check no convolution would run and the
                raw input would be handed to the linear layer.
        """
        size = x.shape[-1]
        if size not in self._SUPPORTED_SIZES:
            raise ValueError(
                f"Unsupported input size {size}; "
                f"expected one of {self._SUPPORTED_SIZES}"
            )

        # Level 0 starts at in1/out1, level 5 starts at in6/out6.
        level = self._SUPPORTED_SIZES.index(size)

        x = self.prelu(getattr(self, f"in{level + 1}")(x))
        for idx in range(level + 1, 7):
            x = self.prelu(getattr(self, f"out{idx}")(x))
        x = self.prelu(self.essen(x))

        return self.linear(self.flatten(x))

Naive

 

 

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder whose depth adapts to the input resolution.

    The size of the last input dimension selects an encoder level:
    level ``k`` feeds the input through ``input_layers[k]`` and then through
    ``output_layers[k:]``.  Each of these 5x5 convolutions (``padding=1``,
    ``stride=1``) shrinks the spatial size by 2, so smaller inputs skip the
    earlier shared layers.  The strided ``essen`` convolution and a linear
    layer expecting 2048 features finish the encoding.

    Args:
        latent_dim: Number of features in the encoded output vector.
    """

    def __init__(self, latent_dim: int) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        # A single PReLU instance (one learnable slope) is shared by all layers.
        self.prelu = nn.PReLU()

        # Create the per-level input stems and the shared output layers.
        self.input_layers = nn.ModuleList(
            [nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        self.output_layers = nn.ModuleList(
            [nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        self.essen = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1)

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, self.latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int]:
        """Encode a batch of images and report which encoder level was used.

        Args:
            x: Input batch with 3 channels; the size of its last dimension
                selects the encoder level via ``(32 - size) // 2``.

        Returns:
            A tuple ``(encoded, encoder_level)`` where ``encoded`` has
            ``latent_dim`` features per sample and ``encoder_level`` is the
            index of the input stem that was used (0 for size 32, 5 for
            size 22).

        Raises:
            ValueError: If the input size maps to a level outside
                ``0 .. len(input_layers) - 1``.
        """
        height = x.shape[-1]

        # Choose the entry layer according to the input size.
        encoder_level = (32 - height) // 2

        # ModuleList accepts negative indices, so without this check an input
        # larger than 32 would silently wrap around to the wrong stem and run
        # some shared layers twice.
        if not 0 <= encoder_level < len(self.input_layers):
            raise ValueError(
                f"Unsupported input size {height}: encoder level "
                f"{encoder_level} is outside 0..{len(self.input_layers) - 1}"
            )

        x = self.prelu(self.input_layers[encoder_level](x))
        for layer in self.output_layers[encoder_level:]:
            x = self.prelu(layer(x))
        x = self.prelu(self.essen(x))

        x = self.flatten(x)
        encoded = self.linear(x)

        return encoded, encoder_level

'DE > Code' 카테고리의 다른 글

Proposed Network Architecture  (0) 2024.08.10
Adaptive decoder  (0) 2024.08.06
Proposed net  (0) 2024.08.06
Filter counting of zero padding  (0) 2024.08.06
Patch selection code (CBS)  (0) 2024.08.05
Comments