Adaptive encoder (main code)
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim

        # One entry conv per supported input height (32, 30, 28, 26, 24, 22)
        self.in1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)
        self.in2 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)
        self.in3 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)
        self.in4 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)
        self.in5 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)
        self.in6 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1)

        # Shared body convs; each 5x5 / stride 1 / padding 1 conv shrinks H and W by 2
        self.out1 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.out2 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.out3 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.out4 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.out5 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)
        self.out6 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1)

        self.prelu = nn.PReLU()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, self.latent_dim)
        # Strided conv shared by every branch: 16 x 18 x 18 -> 32 x 8 x 8 (= 2048 features)
        self.essen = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1)

    def forward(self, x):
        height = x.shape[-1]

        # Route the input through a different entry conv depending on its height,
        # then apply the remaining shared convs so every path ends at 32 x 8 x 8
        if height == 32:
            x = self.prelu(self.in1(x))
            x = self.prelu(self.out1(x))
            x = self.prelu(self.out2(x))
            x = self.prelu(self.out3(x))
            x = self.prelu(self.out4(x))
            x = self.prelu(self.out5(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)
        if height == 30:
            x = self.prelu(self.in2(x))
            x = self.prelu(self.out2(x))
            x = self.prelu(self.out3(x))
            x = self.prelu(self.out4(x))
            x = self.prelu(self.out5(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)
        if height == 28:
            x = self.prelu(self.in3(x))
            x = self.prelu(self.out3(x))
            x = self.prelu(self.out4(x))
            x = self.prelu(self.out5(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)
        if height == 26:
            x = self.prelu(self.in4(x))
            x = self.prelu(self.out4(x))
            x = self.prelu(self.out5(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)
        if height == 24:
            x = self.prelu(self.in5(x))
            x = self.prelu(self.out5(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)
        if height == 22:
            x = self.prelu(self.in6(x))
            x = self.prelu(self.out6(x))
            x = self.prelu(self.essen(x))
            print(x.shape)

        x = self.flatten(x)
        encoded = self.linear(x)
        return encoded
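
Regardless of which branch is taken, each 5x5 / stride 1 / padding 1 conv shrinks the spatial size by 2, so every path reaches a 16 x 18 x 18 feature map before essen; the stride-2 essen conv then produces 32 x 8 x 8 = 2048 features, which is exactly the input size of the final Linear layer. A minimal smoke test is sketched below; latent_dim=64 and the batch size are arbitrary example values, not taken from the post.

import torch

encoder = Encoder(latent_dim=64)                 # latent_dim=64 is an arbitrary choice
for h in (32, 30, 28, 26, 24, 22):               # the six supported input heights
    dummy = torch.randn(4, 3, h, h)              # random 3-channel crops of height/width h
    z = encoder(dummy)                           # prints torch.Size([4, 32, 8, 8]) before flatten
    assert z.shape == (4, 64)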
Naive
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.prelu = nn.PReLU()

        # Create the input (entry) layers and the output (shared body) layers
        self.input_layers = nn.ModuleList(
            [nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        self.output_layers = nn.ModuleList(
            [nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1) for _ in range(6)]
        )
        self.essen = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048, self.latent_dim)

    def forward(self, x):
        height = x.shape[-1]

        # Select the entry layer according to the input height
        encoder_level = (32 - height) // 2
        x = self.prelu(self.input_layers[encoder_level](x))
        for i in range(encoder_level, 6):
            x = self.prelu(self.output_layers[i](x))
        x = self.prelu(self.essen(x))
        print(x.shape)

        x = self.flatten(x)
        encoded = self.linear(x)
        return encoded, encoder_level
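
The ModuleList version collapses the six branches into index arithmetic: encoder_level = (32 - height) // 2 maps heights 32, 30, ..., 22 to levels 0 through 5, the level picks the entry conv, and the loop applies the remaining shared convs, so the computation matches the branch-per-height version while additionally returning the level. A minimal sketch with the same arbitrary latent_dim and batch size as above:

import torch

encoder = Encoder(latent_dim=64)
for h in (32, 30, 28, 26, 24, 22):
    z, level = encoder(torch.randn(2, 3, h, h))
    assert z.shape == (2, 64)
    assert level == (32 - h) // 2            # height 32 -> level 0, ..., height 22 -> level 5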