UOMOP

Proposed Network Architecture 본문

DE/Code

Proposed Network Architecture

Happy PinGu 2024. 8. 10. 20:25
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Multi-resolution convolutional encoder mapping an image to a latent vector.

    The spatial size of the input, read from its last dimension, selects one
    of four convolution stacks. Each stack has its own entry convolution
    (``in1`` .. ``in4``) and then joins the shared ``out*`` chain; smaller
    inputs join later and therefore pass through fewer layers:

        size 32 -> level 1: in1, out1, out2, out3, out4
        size 28 -> level 2: in2, out2, out3, out4
        size 24 -> level 3: in3, out3, out4
        size 20 -> level 4: in4, out4

    Every stack is followed by the strided ``essen`` convolution, adaptive
    average pooling to 8x8, flattening and a linear projection.

    Args:
        latent_dim: Number of features in the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Entry convolutions, one per supported input size (3 -> 32 channels).
        # With kernel 5 and no padding each one shrinks the map by 4 pixels.
        self.in1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=0)
        self.in2 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=0)
        self.in3 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=0)
        self.in4 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=0)

        # Shared chain; kernel 5 with padding 1 shrinks the map by 2 pixels.
        self.out1 = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.out2 = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.out3 = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.out4 = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=1)

        # A single PReLU module (one learnable slope) is reused after every
        # convolution.
        self.prelu = nn.PReLU()

        # Brings every level to the same 8x8 map so one linear layer fits all.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))

        self.flatten = nn.Flatten()
        # 2048 = 32 channels * 8 * 8 pooled positions.
        self.linear = nn.Linear(2048, self.latent_dim)

        # Strided convolution applied last on every level.
        self.essen = nn.Conv2d(32, 32, kernel_size=5, stride=2, padding=1)

    def forward(self, x):
        """Encode ``x`` and report which resolution level was used.

        Args:
            x: Image batch with 3 channels whose last dimension is 32, 28, 24
                or 20.

        Returns:
            A tuple ``(encoded, encoder_level)``: the output of the final
            linear layer and an ``int`` level in 1..4 (1 for size 32, 4 for
            size 20).

        Raises:
            ValueError: If the last dimension of ``x`` is not a supported size.
                Without this check such inputs would skip every convolution
                and either fail at the linear layer with an unrelated shape
                error or return a meaningless latent.
        """
        size = x.shape[-1]

        stacks = {
            32: (self.in1, self.out1, self.out2, self.out3, self.out4),
            28: (self.in2, self.out2, self.out3, self.out4),  # Selection Ratio : 0.765625
            24: (self.in3, self.out3, self.out4),             # Selection Ratio : 0.5625
            20: (self.in4, self.out4),                        # Selection Ratio : 0.390625
        }
        if size not in stacks:
            raise ValueError(
                "Unsupported input size {}; expected one of {}.".format(
                    size, sorted(stacks, reverse=True)
                )
            )

        for conv in stacks[size]:
            x = self.prelu(conv(x))
        x = self.prelu(self.essen(x))

        x = self.pool(x)
        x = self.flatten(x)
        encoded = self.linear(x)

        # Supported sizes are 4 pixels apart, so this yields 1, 2, 3, 4.
        encoder_level = (32 - size) // 4 + 1

        return encoded, encoder_level
class Decoder(nn.Module):
    """Multi-resolution decoder reconstructing an image from a latent vector.

    The latent is projected to a 32x8x8 map and upsampled to 18x18 by the
    strided ``essen`` transposed convolution. ``encoder_level`` then selects
    how many further transposed convolutions are applied, which fixes the
    output size:

        level 1: in4, in3, in2, in1, out1 -> 32x32
        level 2: in4, in3, in2, out2      -> 28x28
        level 3: in4, in3, out3           -> 24x24
        level 4: in4, out4                -> 20x20

    Args:
        latent_dim: Number of features in the incoming latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # 2048 = 32 channels * 8 * 8, reshaped by ``unflatten``.
        self.linear = nn.Linear(self.latent_dim, 2048)
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))

        # Stride-2 upsampling shared by every level (8x8 -> 18x18).
        self.essen = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=2, padding=1, output_padding=1)

        # Shared growth chain: padding 1 adds 2 pixels, padding 0 adds 4.
        self.in4 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.in3 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.in2 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=1)
        self.in1 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=0)

        # Per-level output heads (32 -> 3 channels); the padding of each head
        # is what lands the result on that level's target size.
        self.out4 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=2)
        self.out3 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=1)
        self.out2 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=0)
        self.out1 = nn.ConvTranspose2d(32, 3, kernel_size=5, stride=1, padding=0)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        """Decode latent ``x`` along the path selected by ``encoder_level``.

        Args:
            x: Latent batch with ``latent_dim`` features per sample.
            encoder_level: Level in 1..4, as returned by the encoder.

        Returns:
            The reconstructed batch with 3 channels, passed through a sigmoid.

        Raises:
            ValueError: If ``encoder_level`` is not 1, 2, 3 or 4. Without this
                check the 32-channel intermediate map would be returned as if
                it were an image.
        """
        # Resolve the path first so an invalid level fails before any compute.
        if encoder_level == 1:
            layers = (self.in4, self.in3, self.in2, self.in1, self.out1)
        elif encoder_level == 2:
            layers = (self.in4, self.in3, self.in2, self.out2)
        elif encoder_level == 3:
            layers = (self.in4, self.in3, self.out3)
        elif encoder_level == 4:
            layers = (self.in4, self.out4)
        else:
            raise ValueError(
                "Unsupported encoder_level {!r}; expected 1, 2, 3 or 4.".format(encoder_level)
            )

        x = self.essen(self.unflatten(self.prelu(self.linear(x))))

        # NOTE(review): no activation is applied between these transposed
        # convolutions - confirm that this is intended.
        for layer in layers:
            x = layer(x)

        decoded = self.sigmoid(x)

        return decoded
# Smoke test: push one random RGB image through the encoder and the decoder
# and print the tensor shape produced at each stage.
img_size = 32

input_image = torch.randn((1, 3, img_size, img_size))

encoder = Encoder(latent_dim=128)
decoder = Decoder(latent_dim=128)

# The encoder returns the latent together with the level it selected.
encoded, encoder_level = encoder(input_image)
print("Encoded output shape:", encoded.shape)
print("Encoder Level : {}".format(encoder_level))

# The same level is handed to the decoder so it can pick its matching path.
decoded = decoder(encoded, encoder_level)
print("Decoded output shape:", decoded.shape)

Encoded output shape: torch.Size([1, 128])
Encoder Level : 1
Decoded output shape: torch.Size([1, 3, 32, 32])

'DE > Code' 카테고리의 다른 글

Masked, Reshaped, index Gen  (0) 2024.08.10
DeepJSCC BenchMark  (0) 2024.08.10
Adaptive decoder  (0) 2024.08.06
Adaptive encoder  (0) 2024.08.06
Proposed net  (0) 2024.08.06
Comments