UOMOP

Adaptive decoder 본문

DE/Code

Adaptive decoder

Happy PinGu 2024. 8. 6. 14:31
class Decoder(nn.Module):
    """Adaptive decoder whose depth depends on the encoder level.

    A latent vector is projected to 2048 features, reshaped to a (32, 8, 8)
    feature map and upsampled by ``essen``. For ``encoder_level = k`` the inner
    transposed convolutions ``in6`` down to ``in{k+1}`` (``6 - k`` layers) are
    applied, followed by the matching head ``out{k+1}`` and a sigmoid.

    With the kernel/stride/padding settings below, ``essen`` maps 8x8 to 18x18
    and every stride-1 layer adds 2 to each spatial side. The output is
    therefore ``3 x (32 - 2k) x (32 - 2k)``.

    Args:
        latent_dim: Size of the last dimension of the latent input.
    """

    # Supported values of ``encoder_level`` are 0 .. NUM_LEVELS - 1.
    NUM_LEVELS = 6

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # 2048 == 32 * 8 * 8, so the projection can be unflattened to a feature map.
        self.linear = nn.Linear(self.latent_dim, 2048)
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))

        # 8x8 -> 18x18 with these settings.
        self.essen = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=1, output_padding=1)

        # Attribute names and registration order are kept unchanged so that
        # existing state_dict checkpoints still load into this class.
        # Each of these stride-1 layers grows H and W by 2.
        self.in6 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in5 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in4 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in3 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in2 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in1 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)

        # One RGB output head per level; out{k+1} is used for encoder_level k.
        self.out6 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out5 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out4 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out3 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out2 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out1 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        """Decode ``x`` along the path selected by ``encoder_level``.

        Args:
            x: Latent batch; presumably shape (batch, latent_dim). TODO confirm.
            encoder_level: Integer in ``[0, NUM_LEVELS - 1]``. Lower levels use
                more inner layers.

        Returns:
            Sigmoid-activated tensor with 3 channels.

        Raises:
            ValueError: If ``encoder_level`` is not a supported level. Previously
                such a value silently returned a 16-channel map.
        """
        if encoder_level not in range(self.NUM_LEVELS):
            raise ValueError(
                f"encoder_level must be an integer in [0, {self.NUM_LEVELS - 1}], "
                f"got {encoder_level!r}"
            )
        level = int(encoder_level)

        # Ordered as applied: level k runs the first (6 - k) of these.
        inner_layers = (self.in6, self.in5, self.in4, self.in3, self.in2, self.in1)
        output_layers = (self.out1, self.out2, self.out3, self.out4, self.out5, self.out6)

        x = self.essen(self.unflatten(self.prelu(self.linear(x))))

        # NOTE(review): there is no nonlinearity between these transposed
        # convolutions, so the chain is affine up to the sigmoid. Confirm intended.
        for layer in inner_layers[: self.NUM_LEVELS - level]:
            x = layer(x)
        x = output_layers[level](x)

        return self.sigmoid(x)

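Naive version (same layer sequence as above, with the six inner and six output layers stored in nn.ModuleList)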

import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Adaptive decoder whose inner and output layers live in ``nn.ModuleList``s.

    ``inner_layers[i]`` plays the role of ``in{i+1}`` and ``output_layers[i]``
    the role of ``out{i+1}`` in the explicit version. For ``encoder_level = k``
    the inner layers ``num_levels - 1`` down to ``k`` are applied, then
    ``output_layers[k]`` and a sigmoid.

    With the layer settings below, ``essen`` maps 8x8 to 18x18 and each
    stride-1 layer adds 2 to each spatial side. The output side is therefore
    ``18 + 2 * (num_levels - k) + 2``, which is 32 - 2k for the default 6 levels.

    Args:
        latent_dim: Size of the last dimension of the latent input.
        num_levels: Number of supported encoder levels, and of inner/output
            layers. Defaults to 6, the previously hard-coded value.

    Raises:
        ValueError: If ``num_levels`` is smaller than 1.
    """

    def __init__(self, latent_dim, num_levels=6):
        super().__init__()
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels!r}")
        self.latent_dim = latent_dim
        self.num_levels = num_levels

        # 2048 == 32 * 8 * 8, so the projection can be unflattened to a feature map.
        self.linear = nn.Linear(self.latent_dim, 2048)
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))

        # 8x8 -> 18x18 with these settings.
        self.essen = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=1, output_padding=1)

        # Inner 16->16 transposed convolutions; each grows H and W by 2.
        self.inner_layers = nn.ModuleList(
            [nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1) for _ in range(num_levels)]
        )

        # One final 16->3 output layer per encoder level.
        self.output_layers = nn.ModuleList(
            [nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1) for _ in range(num_levels)]
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        """Decode ``x`` along the path selected by ``encoder_level``.

        Args:
            x: Latent batch; presumably shape (batch, latent_dim). TODO confirm.
            encoder_level: Integer in ``[0, num_levels - 1]``.

        Returns:
            Sigmoid-activated tensor with 3 channels.

        Raises:
            ValueError: If ``encoder_level`` is out of range. Negative values used
                to be accepted silently through negative list indexing.
        """
        if encoder_level not in range(self.num_levels):
            raise ValueError(
                f"encoder_level must be an integer in [0, {self.num_levels - 1}], "
                f"got {encoder_level!r}"
            )
        level = int(encoder_level)

        x = self.essen(self.unflatten(self.prelu(self.linear(x))))

        # Deepest layer first: inner_layers[num_levels - 1] down to inner_layers[level].
        # NOTE(review): no nonlinearity between these layers; confirm intended.
        for idx in range(self.num_levels - 1, level - 1, -1):
            x = self.inner_layers[idx](x)

        x = self.output_layers[level](x)

        return self.sigmoid(x)

'DE > Code' 카테고리의 다른 글

DeepJSCC BenchMark  (0) 2024.08.10
Proposed Network Architecture  (0) 2024.08.10
Adaptive encoder  (0) 2024.08.06
Proposed net  (0) 2024.08.06
Filter counting of zero padding  (0) 2024.08.06
Comments