Adaptive decoder
import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.linear = nn.Linear(self.latent_dim, 2048)
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))
        # Shared first upsampling layer (8x8 feature map -> 18x18)
        self.essen = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=1, output_padding=1)
        # Inner refinement layers; each one enlarges the feature map by 2 in height and width
        self.in6 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in5 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in4 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in3 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in2 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.in1 = nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1)
        # Output layers; one 16 -> 3 channel projection per encoder level
        self.out6 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out5 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out4 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out3 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out2 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.out1 = nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        # Common stem: linear projection, reshape to (32, 8, 8), then first upsampling
        x = self.essen(self.unflatten(self.prelu(self.linear(x))))
        # The lower the encoder level, the more inner layers the decoder applies
        if encoder_level == 0:
            x = self.in6(x)
            x = self.in5(x)
            x = self.in4(x)
            x = self.in3(x)
            x = self.in2(x)
            x = self.in1(x)
            x = self.out1(x)
        elif encoder_level == 1:
            x = self.in6(x)
            x = self.in5(x)
            x = self.in4(x)
            x = self.in3(x)
            x = self.in2(x)
            x = self.out2(x)
        elif encoder_level == 2:
            x = self.in6(x)
            x = self.in5(x)
            x = self.in4(x)
            x = self.in3(x)
            x = self.out3(x)
        elif encoder_level == 3:
            x = self.in6(x)
            x = self.in5(x)
            x = self.in4(x)
            x = self.out4(x)
        elif encoder_level == 4:
            x = self.in6(x)
            x = self.in5(x)
            x = self.out5(x)
        elif encoder_level == 5:
            x = self.in6(x)
            x = self.out6(x)
        decoded = self.sigmoid(x)
        return decoded
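
A minimal usage sketch (not part of the original post; the latent dimension and batch size below are assumptions chosen only for illustration) showing that the decoder runs a different number of layers depending on encoder_level, so the output resolution changes with the level:

import torch

decoder = Decoder(latent_dim=128)   # hypothetical latent dimension
z = torch.randn(4, 128)             # hypothetical batch of 4 latent vectors

for level in range(6):
    out = decoder(z, encoder_level=level)
    # lower levels run more inner layers, so the spatial size of the output is larger
    print(level, out.shape)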
Naive
import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.linear = nn.Linear(self.latent_dim, 2048)
        self.prelu = nn.PReLU()
        self.unflatten = nn.Unflatten(1, (32, 8, 8))
        self.essen = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2, padding=1, output_padding=1)
        # Create the list of inner ConvTranspose2d layers (index 5 corresponds to in6, index 0 to in1)
        self.inner_layers = nn.ModuleList(
            [nn.ConvTranspose2d(16, 16, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        # Create the list of final output layers (index i corresponds to out(i+1))
        self.output_layers = nn.ModuleList(
            [nn.ConvTranspose2d(16, 3, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, encoder_level):
        x = self.essen(self.unflatten(self.prelu(self.linear(x))))
        # Apply the inner layers from index 5 down to encoder_level, then the matching output layer
        for i in range(5, encoder_level - 1, -1):
            x = self.inner_layers[i](x)
        x = self.output_layers[encoder_level](x)
        decoded = self.sigmoid(x)
        return decoded
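
As a quick sanity check (a sketch added here, not from the original post), unrolling the loop indices confirms that this version applies the same layers as the branch-based decoder above: inner_layers[5] down to inner_layers[encoder_level], followed by output_layers[encoder_level].

for level in range(6):
    inner = [f"inner_layers[{i}]" for i in range(5, level - 1, -1)]
    print(level, " -> ".join(inner + [f"output_layers[{level}]"]))
# level 0 runs all six inner layers (in6 ... in1 in the branch version) followed by out1;
# level 5 runs only inner_layers[5] (in6) followed by out6.

Keeping the layers in nn.ModuleList rather than a plain Python list also matters: ModuleList registers each layer as a submodule, so its parameters appear in decoder.parameters() and are trained and saved along with the rest of the model.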