源码结构

➜  LF-VSN-main tree
├── README.md
├── assets
│   ├── overview.PNG
│   └── performance.PNG
└── code
    ├── data
    │   ├── Vimeo90K_dataset.py
    │   ├── __init__.py
    │   ├── data_sampler.py
    │   ├── util.py
    │   └── video_test_dataset.py
    ├── models
    │   ├── LFVSN.py
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── discrim.py
    │   ├── lr_scheduler.py
    │   ├── modules
    │   │   ├── Inv_arch.py
    │   │   ├── Quantization.py
    │   │   ├── Subnet_constructor.py
    │   │   ├── __init__.py
    │   │   ├── common.py
    │   │   ├── loss.py
    │   │   └── module_util.py
    │   └── networks.py
    ├── options
    │   ├── __init__.py
    │   ├── options.py
    │   └── train
    │       ├── train_LF-VSN_1video.yml
    │       ├── train_LF-VSN_2video.yml
    │       ├── train_LF-VSN_3video.yml
    │       ├── train_LF-VSN_4video.yml
    │       ├── train_LF-VSN_5video.yml
    │       ├── train_LF-VSN_6video.yml
    │       └── train_LF-VSN_7video.yml
    ├── test.py
    ├── train.py
    └── utils
        ├── __init__.py
        └── util.py

各文件和目录的作用

code/ 目录

code/ 下的Python脚本

  • train.py:训练脚本,用于训练模型。读取配置文件,创建数据加载器,初始化模型,执行训练循环,包括模型的保存和日志记录等。
  • test.py:测试脚本,用于评估模型。加载训练好的模型,对测试数据进行推理,计算评价指标(如 PSNR),并保存输出结果。

code/ 的子目录

code/data/

该目录包含数据处理和加载的相关代码。

  • __init__.py:使该目录成为一个 Python 包。
  • Vimeo90K_dataset.py:定义了 Vimeo90KDataset 类,用于加载 Vimeo90K 数据集的训练数据。负责读取图像、进行数据增强和预处理。
  • video_test_dataset.py:定义了 VideoTestDataset 类,用于加载测试数据集。支持加载 Vid4、REDS4、Vimeo90K-Test 等数据集。
  • data_sampler.py:定义了分布式训练时的数据采样器 DistIterSampler,用于在多 GPU 训练时正确划分数据。
  • util.py:数据处理的实用函数,例如图像读取、图像增强、通道转换、图像缩放等。
code/models/

该目录包含模型的定义和相关模块。

  • __init__.py:用于创建模型的函数,根据配置创建对应的模型实例。
  • base_model.py:定义了 BaseModel 类,所有模型的基类,包含了一些通用的方法,例如保存和加载模型、更新学习率等。
  • LFVSN.py:主要的模型文件,定义了 Model_VSN 类,是 LF-VSN 模型的具体实现,包括前向传播、损失计算、优化等。
  • networks.py:定义了模型的网络结构,包含了创建生成器(Generator)的函数。
  • modules/:模型使用的各种模块和工具。

    • __init__.py:使该目录成为一个 Python 包。
    • Inv_arch.py:定义了可逆网络的架构,包括 InvBlockInvNNPredictiveModuleMIMO 等类。
    • Quantization.py:实现了量化模块,用于模拟图像的量化过程。
    • Subnet_constructor.py:定义了子网络构造函数,用于构建模型中的子网络(如 DenseBlock)。
    • common.py:一些通用的模型组件,例如 DWT(离散小波变换)和 IWT(逆小波变换)等。
    • loss.py:定义了损失函数,包括重建损失、GAN 损失等。
    • module_util.py:模型工具函数,例如权重初始化、卷积块定义等。
code/options/

该目录包含配置相关的代码和配置文件。

  • __init__.py:使该目录成为一个 Python 包。
  • options.py:定义了解析配置文件的函数,将 YAML 文件解析为 Python 字典,并提供了一些配置检查的功能。
  • train/:存放训练的配置文件,每个配置文件对应不同的训练设置。

    • train_LF-VSN_1video.yml:用于训练隐藏 1 个秘密视频的配置文件。
    • train_LF-VSN_2video.yml:用于训练隐藏 2 个秘密视频的配置文件。
    • ...(一直到 train_LF-VSN_7video.yml):分别对应隐藏 1 到 7 个秘密视频的训练配置。
code/utils/

该目录包含一些工具函数和日志相关的代码。

  • __init__.py:使该目录成为一个 Python 包。
  • util.py:实用工具函数,例如创建文件夹、设置随机种子、计算 PSNR 等。

详细说明

1. 训练脚本(train.py)

  • 作用:负责模型的训练过程。
  • 主要功能

    • 解析训练配置文件,获取训练参数。
    • 创建数据加载器,准备训练和验证数据集。
    • 初始化模型,加载预训练权重(如果有)。
    • 设置优化器和学习率调度器。
    • 执行训练循环,包括前向传播、计算损失、反向传播和参数更新。
    • 记录训练日志,保存模型检查点。

2. 测试脚本(test.py)

  • 作用:用于模型的测试和评估。
  • 主要功能

    • 解析测试配置文件,获取测试参数。
    • 创建测试数据加载器。
    • 加载训练好的模型权重。
    • 对测试数据进行推理,生成输出结果。
    • 计算评价指标(如 PSNR),评估模型性能。
    • 保存测试输出,例如生成的图像或视频帧。

3. 数据相关代码(code/data/)

a. Vimeo90K_dataset.py

  • 作用:定义了用于训练的 Vimeo90KDataset 类。
  • 主要功能

    • 从指定的路径读取 Vimeo90K 数据集的图像序列。
    • 支持随机裁剪、数据增强(如翻转、旋转)等操作。
    • 根据 num_video 参数,加载多个秘密视频。

b. video_test_dataset.py

  • 作用:定义了用于测试的 VideoTestDataset 类。
  • 主要功能

    • 加载测试数据集的视频序列。
    • 支持不同的数据集(如 Vid4、REDS4)的加载方式。
    • 为模型提供测试所需的数据。

c. data_sampler.py

  • 作用:定义了用于分布式训练的数据采样器 DistIterSampler
  • 主要功能

    • 在分布式训练时,确保每个进程(GPU)获取不同的数据子集。
    • 支持对数据集的扩展,以适应迭代次数的需求。

d. util.py(数据部分)

  • 作用:提供数据处理的实用函数。
  • 主要功能

    • 图像的读取和保存。
    • 图像增强(如随机翻转、旋转)。
    • 通道转换(如 RGB 转灰度)。
    • 图像缩放、裁剪等操作。

4. 模型相关代码(code/models/)

a. LFVSN.py

  • 作用:定义了 LF-VSN 模型的核心类 Model_VSN
  • 主要功能

    • 继承自 BaseModel,实现模型的初始化、前向传播、损失计算等。
    • 根据配置,创建生成器(Generator)网络。
    • 定义了模型的训练过程,包括前向和反向的损失计算。
    • 提供了测试方法,用于在验证或测试时生成输出。

b. networks.py

  • 作用:定义了创建网络的函数。
  • 主要功能

    • 根据配置文件,构建生成器网络。
    • 主要使用了可逆网络(INN)的结构。

c. discrim.py

  • 作用:定义鉴别器网络和GAN损失函数
  • 主要功能

    • 定义UNetDiscriminatorSN类
    • 定义GANLoss类

c. lr_scheduler.py

  • 作用:定义了自定义的学习率调度器,用于在训练过程中调整优化器的学习率
  • 主要功能

    • MultiStepLR_Restart:多步学习率调度器
    • CosineAnnealingLR_Restart:余弦退火学习率调度器

c. modules/

  • Inv_arch.py:定义了可逆网络的架构组件。

    • InvBlock:可逆块,实现了基本的可逆操作。
    • InvNN:可逆神经网络,由多个可逆块组成。
    • PredictiveModuleMIMO:预测模块,用于生成隐藏视频的预测。
  • Quantization.py:实现了量化操作,用于模拟图像在实际存储和传输中的量化过程。
  • Subnet_constructor.py:定义了子网络的构造函数。

    • 提供了构建子网络(如 DenseBlock)的方式,支持不同的初始化方法。
  • common.py:一些通用的模型组件。

    • DWT:离散小波变换,用于提取图像的频域特征。
    • IWT:逆小波变换,将频域特征还原为图像。
  • loss.py:定义了模型训练过程中使用的损失函数。

    • ReconstructionLoss:重建损失,用于衡量输出与目标之间的差异。
    • GANLoss:对抗损失(如果使用 GAN 结构)。
  • module_util.py:提供了模型工具函数。

    • initialize_weights:权重初始化函数。
    • make_layer:用于构建网络层的辅助函数。

5. 配置相关代码(code/options/)

a. options.py

  • 作用:解析配置文件,将 YAML 文件转换为 Python 字典。
  • 主要功能

    • 提供了 parse 函数,读取配置文件。
    • 检查配置中的路径,确保必要的字段存在。

b. train/

  • 作用:存放不同训练设置的配置文件。
  • 每个配置文件的主要内容

    • general settings:一般设置,包括模型名称、使用的 GPU、隐藏的视频数量等。
    • datasets:数据集相关的配置,指定训练和验证数据集的位置和参数。
    • network structures:网络结构的配置,指定模型的架构和参数。
    • path:路径设置,包括预训练模型、模型保存路径等。
    • training settings:训练参数,包括学习率、优化器设置、损失函数权重等。
    • logger:日志和模型保存的设置。

6. 工具和日志(code/utils/)

a. logger.py

  • 作用:设置日志记录器。
  • 主要功能

    • 定义了日志的格式和级别。
    • 支持将日志输出到文件和控制台。

b. util.py(工具部分)

  • 作用:提供一些实用的工具函数。
  • 主要功能

    • mkdirmkdirs:创建目录的函数。
    • set_random_seed:设置随机种子,以确保实验的可重复性。
    • calculate_psnr:计算峰值信噪比(PSNR)的函数,用于评估图像质量。
    • tensor2img:将 PyTorch 张量转换为图像格式,便于保存和可视化。

值得学习或修改的源码

└── code
    ├── data
    │   ├── Vimeo90K_dataset.py
    │   ├── __init__.py
    │   ├── data_sampler.py
    │   ├── util.py
    │   └── video_test_dataset.py
    ├── models
    │   ├── LFVSN.py        # 定义LFVSN网络结构
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── discrim.py      # 定义鉴别器网络和GAN损失
    │   ├── lr_scheduler.py
    │   ├── networks.py     # 定义生成器网络结构
    │   ├── modules
    │   │   ├── Inv_arch.py    # 定义了可逆神经网络的架构
    │   │   ├── Quantization.py
    │   │   ├── Subnet_constructor.py    # 用于构建模型的子网络
    │   │   ├── __init__.py
    │   │   ├── common.py      # 一些通用模型组件,包括DWT和IWT
    │   │   ├── loss.py        # 定义损失函数
    │   └   └── module_util.py
    ├── options
    │   ├── __init__.py
    │   ├── options.py
    ├── test.py    # 测试脚本,可帮助了解模型推理过程
    ├── train.py   # 训练脚本,可帮助了解训练流程
    └── utils
        ├── __init__.py
        └── util.py    # 定义了一些工具函数,可帮助了解PSNR等性能指标的原理
最后修改:2024 年 09 月 21 日
如果觉得我的文章对你有用,请随意赞赏