源码结构
➜ LF-VSN-main tree
├── README.md
├── assets
│ ├── overview.PNG
│ └── performance.PNG
└── code
├── data
│ ├── Vimeo90K_dataset.py
│ ├── __init__.py
│ ├── data_sampler.py
│ ├── util.py
│ └── video_test_dataset.py
├── models
│ ├── LFVSN.py
│ ├── __init__.py
│ ├── base_model.py
│ ├── discrim.py
│ ├── lr_scheduler.py
│ ├── modules
│ │ ├── Inv_arch.py
│ │ ├── Quantization.py
│ │ ├── Subnet_constructor.py
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── loss.py
│ │ └── module_util.py
│ └── networks.py
├── options
│ ├── __init__.py
│ ├── options.py
│ └── train
│ ├── train_LF-VSN_1video.yml
│ ├── train_LF-VSN_2video.yml
│ ├── train_LF-VSN_3video.yml
│ ├── train_LF-VSN_4video.yml
│ ├── train_LF-VSN_5video.yml
│ ├── train_LF-VSN_6video.yml
│ └── train_LF-VSN_7video.yml
├── test.py
├── train.py
└── utils
├── __init__.py
└── util.py
各文件和目录的作用
code/ 目录
code/ 下的Python脚本
- train.py:训练脚本,用于训练模型。读取配置文件,创建数据加载器,初始化模型,执行训练循环,包括模型的保存和日志记录等。
- test.py:测试脚本,用于评估模型。加载训练好的模型,对测试数据进行推理,计算评价指标(如 PSNR),并保存输出结果。
code/ 的子目录
code/data/
该目录包含数据处理和加载的相关代码。
- __init__.py:使该目录成为一个 Python 包。
- Vimeo90K_dataset.py:定义了
Vimeo90KDataset
类,用于加载 Vimeo90K 数据集的训练数据。负责读取图像、进行数据增强和预处理。 - video_test_dataset.py:定义了
VideoTestDataset
类,用于加载测试数据集。支持加载 Vid4、REDS4、Vimeo90K-Test 等数据集。 - data_sampler.py:定义了分布式训练时的数据采样器
DistIterSampler
,用于在多 GPU 训练时正确划分数据。 - util.py:数据处理的实用函数,例如图像读取、图像增强、通道转换、图像缩放等。
code/models/
该目录包含模型的定义和相关模块。
- __init__.py:用于创建模型的函数,根据配置创建对应的模型实例。
- base_model.py:定义了
BaseModel
类,所有模型的基类,包含了一些通用的方法,例如保存和加载模型、更新学习率等。 - LFVSN.py:主要的模型文件,定义了
Model_VSN
类,是 LF-VSN 模型的具体实现,包括前向传播、损失计算、优化等。 - networks.py:定义了模型的网络结构,包含了创建生成器(Generator)的函数。
modules/:模型使用的各种模块和工具。
- __init__.py:使该目录成为一个 Python 包。
- Inv_arch.py:定义了可逆网络的架构,包括
InvBlock
、InvNN
、PredictiveModuleMIMO
等类。 - Quantization.py:实现了量化模块,用于模拟图像的量化过程。
- Subnet_constructor.py:定义了子网络构造函数,用于构建模型中的子网络(如 DenseBlock)。
- common.py:一些通用的模型组件,例如 DWT(离散小波变换)和 IWT(逆小波变换)等。
- loss.py:定义了损失函数,包括重建损失、GAN 损失等。
- module_util.py:模型工具函数,例如权重初始化、卷积块定义等。
code/options/
该目录包含配置相关的代码和配置文件。
- __init__.py:使该目录成为一个 Python 包。
- options.py:定义了解析配置文件的函数,将 YAML 文件解析为 Python 字典,并提供了一些配置检查的功能。
train/:存放训练的配置文件,每个配置文件对应不同的训练设置。
- train_LF-VSN_1video.yml:用于训练隐藏 1 个秘密视频的配置文件。
- train_LF-VSN_2video.yml:用于训练隐藏 2 个秘密视频的配置文件。
- ...(一直到 train_LF-VSN_7video.yml):分别对应隐藏 1 到 7 个秘密视频的训练配置。
code/utils/
该目录包含一些工具函数和日志相关的代码。
- __init__.py:使该目录成为一个 Python 包。
- util.py:实用工具函数,例如创建文件夹、设置随机种子、计算 PSNR 等。
详细说明
1. 训练脚本(train.py)
- 作用:负责模型的训练过程。
主要功能:
- 解析训练配置文件,获取训练参数。
- 创建数据加载器,准备训练和验证数据集。
- 初始化模型,加载预训练权重(如果有)。
- 设置优化器和学习率调度器。
- 执行训练循环,包括前向传播、计算损失、反向传播和参数更新。
- 记录训练日志,保存模型检查点。
2. 测试脚本(test.py)
- 作用:用于模型的测试和评估。
主要功能:
- 解析测试配置文件,获取测试参数。
- 创建测试数据加载器。
- 加载训练好的模型权重。
- 对测试数据进行推理,生成输出结果。
- 计算评价指标(如 PSNR),评估模型性能。
- 保存测试输出,例如生成的图像或视频帧。
3. 数据相关代码(code/data/)
a. Vimeo90K_dataset.py
- 作用:定义了用于训练的
Vimeo90KDataset
类。 主要功能:
- 从指定的路径读取 Vimeo90K 数据集的图像序列。
- 支持随机裁剪、数据增强(如翻转、旋转)等操作。
- 根据
num_video
参数,加载多个秘密视频。
b. video_test_dataset.py
- 作用:定义了用于测试的
VideoTestDataset
类。 主要功能:
- 加载测试数据集的视频序列。
- 支持不同的数据集(如 Vid4、REDS4)的加载方式。
- 为模型提供测试所需的数据。
c. data_sampler.py
- 作用:定义了用于分布式训练的数据采样器
DistIterSampler
。 主要功能:
- 在分布式训练时,确保每个进程(GPU)获取不同的数据子集。
- 支持对数据集的扩展,以适应迭代次数的需求。
d. util.py(数据部分)
- 作用:提供数据处理的实用函数。
主要功能:
- 图像的读取和保存。
- 图像增强(如随机翻转、旋转)。
- 通道转换(如 RGB 转灰度)。
- 图像缩放、裁剪等操作。
4. 模型相关代码(code/models/)
a. LFVSN.py
- 作用:定义了 LF-VSN 模型的核心类
Model_VSN
。 主要功能:
- 继承自
BaseModel
,实现模型的初始化、前向传播、损失计算等。 - 根据配置,创建生成器(Generator)网络。
- 定义了模型的训练过程,包括前向和反向的损失计算。
- 提供了测试方法,用于在验证或测试时生成输出。
- 继承自
b. networks.py
- 作用:定义了创建网络的函数。
主要功能:
- 根据配置文件,构建生成器网络。
- 主要使用了可逆网络(INN)的结构。
c. discrim.py
- 作用:定义鉴别器网络和GAN损失函数
主要功能:
- 定义UNetDiscriminatorSN类
- 定义GANLoss类
c. lr_scheduler.py
- 作用:定义了自定义的学习率调度器,用于在训练过程中调整优化器的学习率
主要功能:
- MultiStepLR_Restart:多步学习率调度器
- CosineAnnealingLR_Restart:余弦退火学习率调度器
c. modules/
Inv_arch.py:定义了可逆网络的架构组件。
- InvBlock:可逆块,实现了基本的可逆操作。
- InvNN:可逆神经网络,由多个可逆块组成。
- PredictiveModuleMIMO:预测模块,用于生成隐藏视频的预测。
- Quantization.py:实现了量化操作,用于模拟图像在实际存储和传输中的量化过程。
Subnet_constructor.py:定义了子网络的构造函数。
- 提供了构建子网络(如 DenseBlock)的方式,支持不同的初始化方法。
common.py:一些通用的模型组件。
- DWT:离散小波变换,用于提取图像的频域特征。
- IWT:逆小波变换,将频域特征还原为图像。
loss.py:定义了模型训练过程中使用的损失函数。
- ReconstructionLoss:重建损失,用于衡量输出与目标之间的差异。
- GANLoss:对抗损失(如果使用 GAN 结构)。
module_util.py:提供了模型工具函数。
- initialize_weights:权重初始化函数。
- make_layer:用于构建网络层的辅助函数。
5. 配置相关代码(code/options/)
a. options.py
- 作用:解析配置文件,将 YAML 文件转换为 Python 字典。
主要功能:
- 提供了
parse
函数,读取配置文件。 - 检查配置中的路径,确保必要的字段存在。
- 提供了
b. train/
- 作用:存放不同训练设置的配置文件。
每个配置文件的主要内容:
- general settings:一般设置,包括模型名称、使用的 GPU、隐藏的视频数量等。
- datasets:数据集相关的配置,指定训练和验证数据集的位置和参数。
- network structures:网络结构的配置,指定模型的架构和参数。
- path:路径设置,包括预训练模型、模型保存路径等。
- training settings:训练参数,包括学习率、优化器设置、损失函数权重等。
- logger:日志和模型保存的设置。
6. 工具和日志(code/utils/)
a. logger.py
- 作用:设置日志记录器。
主要功能:
- 定义了日志的格式和级别。
- 支持将日志输出到文件和控制台。
b. util.py(工具部分)
- 作用:提供一些实用的工具函数。
主要功能:
- mkdir、mkdirs:创建目录的函数。
- set_random_seed:设置随机种子,以确保实验的可重复性。
- calculate_psnr:计算峰值信噪比(PSNR)的函数,用于评估图像质量。
- tensor2img:将 PyTorch 张量转换为图像格式,便于保存和可视化。
值得学习或修改的源码
└── code
├── data
│ ├── Vimeo90K_dataset.py
│ ├── __init__.py
│ ├── data_sampler.py
│ ├── util.py
│ └── video_test_dataset.py
├── models
│ ├── LFVSN.py # 定义LFVSN网络结构
│ ├── __init__.py
│ ├── base_model.py
│ ├── discrim.py # 定义鉴别器网络和GAN损失
│ ├── lr_scheduler.py
│ ├── networks.py # 定义生成器网络结构
│ ├── modules
│ │ ├── Inv_arch.py # 定义了可逆神经网络的架构
│ │ ├── Quantization.py
│ │ ├── Subnet_constructor.py # 用于构建模型的子网络
│ │ ├── __init__.py
│ │ ├── common.py # 一些通用模型组件,包括DWT和IWT
│ │ ├── loss.py # 定义损失函数
│ └ └── module_util.py
├── options
│ ├── __init__.py
│ ├── options.py
├── test.py # 测试脚本,可帮助了解模型推理过程
├── train.py # 训练脚本,可帮助了解训练流程
└── utils
├── __init__.py
└── util.py # 定义了一些工具函数,可帮助了解PSNR等性能指标的原理