UOMOP
DDPM(Free Guidance)_only 10epoch
import os
from typing import Dict
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
from DiffusionFreeGuidence.ModelCondition import UNet
from Scheduler import GradualWarmupScheduler
#import openpyxl
def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
    # model setup
    net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_dir"], modelConfig["training_load_weight"]), map_location=device), strict=False)
        print("Model weight load done.")
    optimizer = torch.optim.AdamW(net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
                                             warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    trainer = GaussianDiffusionTrainer(net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
    file = open("mse_result.txt", "w")
    # start training
    for e in range(modelConfig["epoch"]):
        total_loss_in_one_epoch = 0
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                b = images.shape[0]
                optimizer.zero_grad()
                x_0 = images.to(device)
                labels = labels.to(device) + 1  # shift labels to 1..10; 0 is reserved for the unconditional token
                if np.random.rand() < 0.1:      # classifier-free guidance: drop the condition 10% of the time
                    labels = torch.zeros_like(labels).to(device)
                loss = trainer(x_0, labels).sum() / b ** 2.
                total_loss_in_one_epoch += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e + 1,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
        aver_loss_in_one_epoch = total_loss_in_one_epoch / len(tqdmDataLoader)
        print("Epoch[{:04d}] : Loss = {:.5f}".format(e + 1, aver_loss_in_one_epoch))
        file.write(str(aver_loss_in_one_epoch) + " ")
        warmUpScheduler.step()
        torch.save(net_model.state_dict(), os.path.join(modelConfig["save_dir"], 'ckpt_' + str(e) + "_.pt"))
    file.close()
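The loop above trains with the classifier-free guidance trick: real labels are shifted to 1..10, and with probability 0.1 the label is replaced by the unconditional token 0. At sampling time, the imported GaussianDiffusionSampler is expected to blend the conditional and unconditional noise predictions using the weight w passed in eval() below. A minimal sketch of that blend, assuming a noise-prediction model with the signature model(x_t, t, labels) (the actual sampler implementation is not shown in this post):

import torch

def guided_eps(model, x_t, t, labels, w):
    # Conditional and unconditional (label 0) noise predictions.
    eps_cond = model(x_t, t, labels)
    eps_uncond = model(x_t, t, torch.zeros_like(labels))
    # Classifier-free guidance: w = 0 keeps only the conditional prediction,
    # larger w pushes the sample further away from the unconditional one.
    return (1. + w) * eps_cond - w * eps_uncond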
def eval(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # load model and evaluate
    with torch.no_grad():
        # build a label batch with batch_size // 10 samples per class (classes 0..9)
        step = int(modelConfig["batch_size"] // 10)
        labelList = []
        k = 0
        for i in range(1, modelConfig["batch_size"] + 1):
            labelList.append(torch.ones(size=[1]).long() * k)
            if i % step == 0:
                if k < 10 - 1:
                    k += 1
        labels = torch.cat(labelList, dim=0).long().to(device) + 1  # shift to 1..10 (conditional tokens)
        print("labels: ", labels)
        model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
        ckpt = torch.load(os.path.join(modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
        model.load_state_dict(ckpt)
        print("Model weight load done.")
        model.eval()
        sampler = GaussianDiffusionSampler(model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
        # start from standard normal noise
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)  # map the noise to [0, 1] for saving
        save_image(saveNoisy, os.path.join(modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        sampledImgs = sampler(noisyImage, labels)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        print(sampledImgs)
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
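Neither train() nor eval() builds modelConfig itself; both simply read keys from a plain dict. The sketch below shows an assumed configuration and entry point using only the keys accessed above. The hyperparameter values, the "state" key, and the dispatch at the bottom are illustrative assumptions, not the exact settings of this 10-epoch run:

if __name__ == "__main__":
    # Illustrative values only; every key except "state" is read somewhere in train()/eval() above.
    modelConfig = {
        "state": "train",            # "train" or "eval" (assumed dispatch flag)
        "epoch": 10,
        "batch_size": 80,            # eval() assigns batch_size // 10 samples to each class
        "T": 500,
        "channel": 128,
        "channel_mult": [1, 2, 2, 2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.5,
        "beta_1": 1e-4,
        "beta_T": 0.028,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "w": 1.8,                    # guidance weight used by GaussianDiffusionSampler
        "training_load_weight": None,
        "save_dir": "./CheckpointsCondition/",      # assumed to exist before training
        "test_load_weight": "ckpt_9_.pt",            # last checkpoint after 10 epochs
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyGuidenceImgs.png",
        "sampledImgName": "SampledGuidenceImgs.png",
        "nrow": 8,
    }
    if modelConfig["state"] == "train":
        train(modelConfig)
    else:
        eval(modelConfig)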