UOMOP

DDPM(Free Guidance)_320epoch 본문

Research/Semantic Communication

DDPM(Free Guidance)_320epoch

Happy PinGu 2023. 11. 29. 11:16
from DiffusionFreeGuidence.TrainCondition import train, eval

def main(model_config=None):
    """Train or evaluate the conditional diffusion model from TrainCondition.

    A default hyper-parameter dictionary is built on every call. Any entries
    in ``model_config`` are overlaid on top of it, and the result is handed
    to ``train`` (when ``"state"`` is ``"train"``) or to ``eval`` (for any
    other ``"state"`` value).

    Args:
        model_config: Optional dict of overrides. Keys it supplies replace
            the defaults and keys it omits keep their default value, so a
            partial dict such as ``{"state": "eval"}`` is enough. A complete
            dict yields the same configuration as before.

    Note:
        ``train`` and ``eval`` are imported from
        ``DiffusionFreeGuidence.TrainCondition``; that ``eval`` shadows the
        Python builtin of the same name inside this module.
    """
    modelConfig = {
        "state": "train",  # "train" -> train(); any other value -> eval()
        "epoch": 320,
        "batch_size": 80,
        "T": 500,  # presumably the number of diffusion steps -- confirm in TrainCondition
        "channel": 128,
        "channel_mult": [1, 2, 2, 2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.5,
        "beta_1": 1e-4,  # presumably the first value of the noise schedule -- confirm
        "beta_T": 0.028,  # presumably the last value of the noise schedule -- confirm
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "w": 1.8,  # presumably the guidance weight used when sampling -- confirm
        "save_dir": "./CheckpointsCondition/",
        "training_load_weight": None,
        # NOTE(review): loads checkpoint 9 although "epoch" is 320; confirm
        # this is the intended checkpoint for evaluation.
        "test_load_weight": "ckpt_9_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyGuidenceImgs.png",
        "sampledImgName": "SampledGuidenceImgs.png",
        "nrow": 8,
    }
    if model_config is not None:
        # Overlay the caller's overrides instead of replacing the whole dict,
        # so every default key is still present for train()/eval() to read.
        modelConfig.update(model_config)
    if modelConfig["state"] == "train":
        train(modelConfig)
    else:
        eval(modelConfig)


if __name__ == "__main__":
    # Executed directly (not imported): run main() with its default config.
    main()

import matplotlib.pyplot as plt

def load_mse_values(file_path='k.txt'):
    """Return every whitespace-separated number in *file_path* as floats.

    All lines are flattened into one list in file order; a line holding
    several numbers contributes all of them, and blank lines contribute
    nothing.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        ValueError: if a token cannot be parsed by ``float``.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [float(value) for line in file for value in line.split()]


def plot_mse_curve(file_path='k.txt'):
    """Plot the MSE values read from *file_path* and show the figure.

    Uses the module-level ``plt`` (``matplotlib.pyplot``) import.
    """
    data = load_mse_values(file_path)
    # With a single argument, plt.plot uses the value index as x. That matches
    # "EPOCH" only if the file holds exactly one value per epoch -- confirm.
    plt.plot(data)
    plt.title('Predicted Noise vs Real Noise')
    plt.xlabel('EPOCH')
    plt.ylabel('MSE')
    plt.show()


if __name__ == '__main__':
    # Plot only when run as a script. At top level this also ran (and could
    # crash on a missing k.txt) whenever the module was merely imported.
    plot_mse_curve()

Comments