Conversation

Chamberlain0w0 (Contributor) commented on Dec 24, 2025

ZeRO-1 optimization based on a distributed optimizer is enabled via --use_distributed_optimizer. Stage 1 shards optimizer state, so it brings no benefit for SGD, but the option is still added to the GPT-2 training script to prepare for the later Stage 2/3 extensions.

Implementation details:

  1. Add the ParamAndGradBucket/Group/Buffer infrastructure, following an approach similar to Megatron-LM: all params/grads are laid out contiguously in a single 1-D buffer and split into buckets that are communicated group by group;
  2. Modify the DistributedDataParallel logic: the constructor now creates the buffers above, partitions the bucket groups, and registers the hooks. The original reducer code path is kept as a separate branch and is still used when --use_distributed_optimizer is not set;
  3. Add a DistributedOptimizer class that inherits from Optimizer. Each optimizer type also gets a small factory function that only takes models->Parameters(); DistributedOptimizer receives this function at construction time and uses it inside its constructor to create the base_optimizer (a rough sketch of this pattern follows the list);
  4. Change the PP constructor parameters so that the optimizer is passed in later, at Module::TrainStep; the PP object therefore no longer holds an optimizer at all;
  5. Print the per-iteration peak allocated/reserved memory in the training loop.
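
To make items 1 and 3 concrete, here is a minimal C++ sketch of the flat param/grad buffer and the creator-function pattern. All type and member names below (FlatBuffer, MakeOptimizerFn, BuildShardParams, ...) are illustrative assumptions, not the actual infini_train API in this PR.

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

struct Tensor {}; // minimal stand-in for infini_train's Tensor

// All params/grads of a module live contiguously in one 1-D allocation,
// split into buckets that are reduced/communicated as a group.
struct FlatBuffer {
    std::shared_ptr<Tensor> data;                    // the contiguous 1-D buffer
    std::vector<std::pair<size_t, size_t>> buckets;  // (offset, numel) per bucket
};

struct Optimizer {
    virtual ~Optimizer() = default;
    virtual void Step() = 0;
};

// Each concrete optimizer gets a small factory that only takes a parameter list,
// so DistributedOptimizer can build the base optimizer over its own shard.
using MakeOptimizerFn
    = std::function<std::unique_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &)>;

class DistributedOptimizer {
public:
    DistributedOptimizer(std::vector<std::shared_ptr<FlatBuffer>> buffers, size_t dp_world_size,
                         size_t dp_rank, const MakeOptimizerFn &make_optimizer)
        : buffers_(std::move(buffers)), dp_world_size_(dp_world_size), dp_rank_(dp_rank) {
        // Build this rank's shard views, then hand them to the base optimizer.
        // The factory is used exactly once here, so it is not stored as a member.
        base_optimizer_ = make_optimizer(BuildShardParams());
    }

    void Step() {
        base_optimizer_->Step(); // update only this rank's shard
        // ZeRO-1: an all-gather of the updated shards would follow here.
    }

private:
    // Real code would slice each buffer into dp_world_size_ shards and return
    // the views owned by dp_rank_; elided in this sketch.
    std::vector<std::shared_ptr<Tensor>> BuildShardParams() { return {}; }

    std::vector<std::shared_ptr<FlatBuffer>> buffers_;
    size_t dp_world_size_ = 1;
    size_t dp_rank_ = 0;
    std::unique_ptr<Optimizer> base_optimizer_;
};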

A detailed walkthrough of the implementation is available here: https://gxtctab8no8.feishu.cn/wiki/XQbGwXSsZi3MutkZuhXcKWsinnY#share-QU0fdM1cYoT06vxmwMbcnLJUn1b

Two important details:

  1. When the DDP object is constructed, it also builds the model's buffers/bucket_groups; the DistributedOptimizer then receives those buffers/groups from the DDP object at construction time;
  2. Because the current PP + DDP implementation wraps each chunk as its own DDP object, the buffers/bucket_groups handed to the DistributedOptimizer must first be aggregated across chunks into one combined list before being passed in (see the sketch after this list).
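
A hedged sketch of detail 2 follows; the accessor names (param_grad_buffers(), bucket_groups()) are assumptions used only for illustration, not the exact PR API. The idea is simply to concatenate the per-chunk buffers and bucket groups into flat lists before constructing the DistributedOptimizer.

#include <memory>
#include <utility>
#include <vector>

// Gather the buffers/bucket groups owned by each DDP-wrapped PP chunk into one
// combined list each; the result is what gets passed to DistributedOptimizer.
template <typename Buffer, typename BucketGroup, typename DDPModule>
std::pair<std::vector<std::shared_ptr<Buffer>>, std::vector<std::shared_ptr<BucketGroup>>>
CollectAcrossChunks(const std::vector<std::shared_ptr<DDPModule>> &ddp_chunks) {
    std::vector<std::shared_ptr<Buffer>> all_buffers;
    std::vector<std::shared_ptr<BucketGroup>> all_groups;
    for (const auto &chunk : ddp_chunks) {
        const auto &bufs = chunk->param_grad_buffers();  // assumed accessor
        const auto &groups = chunk->bucket_groups();     // assumed accessor
        all_buffers.insert(all_buffers.end(), bufs.begin(), bufs.end());
        all_groups.insert(all_groups.end(), groups.begin(), groups.end());
    }
    return {std::move(all_buffers), std::move(all_groups)};
}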

Chamberlain0w0 force-pushed the feat/distributed_optimizer branch from be321ca to 4afb235 on Jan 12, 2026.
Chamberlain0w0 changed the title from "[WIP] feat: Add DistributedOptimizer and support ZeRO-1" to "feat: Add DistributedOptimizer and support ZeRO-1" on Jan 12, 2026.
#include <unordered_map>
#include <vector>

#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"

Collaborator: Use a forward declaration here.

Contributor (author): Changed.

BuildShardParamsAndBindGrads();

// Build base optimizer
base_optimizer_ = creator_(shard_params_);

Collaborator: creator is only called once, in the constructor, right? There's no need to store it as a member.

Contributor (author): Removed.

: Optimizer(full_params), param_grad_buffers_(buffers), bucket_groups_(bucket_groups), dp_pg_(dp_pg),
dp_world_size_(dp_world_size), dp_rank_(dp_rank), creator_(std::move(creator)) {

CHECK(dp_pg_);

Collaborator: dp_pg_ doesn't seem to be used anywhere in DistributedOptimizer.

Contributor (author): Removed for now. It is indeed unused at the moment; only the size and rank are needed. If it becomes necessary later, we can add it back.

namespace infini_train::nn::parallel {

namespace {
std::shared_ptr<Tensor> GetShardView(const std::shared_ptr<Tensor> &buffer, size_t world_size, size_t rank) {

Collaborator: This function is never called.

Contributor (author): Removed.

const std::vector<std::shared_ptr<Tensor>> &full_params,
const std::vector<std::shared_ptr<ParamAndGradBuffer>> &buffers,
const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups,
const ProcessGroup *dp_pg, size_t dp_world_size, size_t dp_rank)

Collaborator: "ddp"

Contributor (author): Changed.

*/
explicit Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<std::vector<size_t>> bucket_indices,
const ReducerOptions &opts);
const DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig());

Collaborator: Avoid default arguments where possible.

Contributor (author): Removed.

}
}

void Tensor::SetData(const Tensor &tensor, size_t offset, bool overwrite) {

Collaborator: The semantics of overwrite here is whether to copy the current Tensor's data into the corresponding position of the target buffer before rebinding to that buffer. The name overwrite is a bit ambiguous; rename it to preserve_data.

Contributor (author): Changed.
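
For illustration only, a standalone toy version of the preserve_data semantics described above; this is not infini_train's Tensor, and the ToyTensor type exists purely to show "copy the old contents into the target slice before rebinding".

#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>
#include <vector>

struct ToyTensor {
    std::shared_ptr<std::vector<float>> storage; // shared backing buffer
    size_t offset = 0;                           // start of this view in the buffer
    size_t numel = 0;                            // number of elements in this view

    float *data() { return storage->data() + offset; }

    // Rebind this tensor to `buffer` at `new_offset`. If preserve_data is true,
    // first copy the current contents into that slice so rebinding keeps them.
    void SetData(const ToyTensor &buffer, size_t new_offset, bool preserve_data) {
        assert(new_offset + numel <= buffer.storage->size());
        if (preserve_data) {
            std::memcpy(buffer.storage->data() + new_offset, data(), numel * sizeof(float));
        }
        storage = buffer.storage;
        offset = new_offset;
    }
};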

@@ -1,5 +1,6 @@
#include "infini_train/include/optimizer.h"

#include <utility>

Collaborator: Is this header really necessary?

Contributor (author): Removed.

const Device *PipelineStage::device() const { return device_; }
const std::vector<std::vector<int64_t>> &PipelineStage::recv_shape() const { return recv_shape_; }
std::shared_ptr<Optimizer> PipelineStage::optimizer() { return optimizer_; }
// std::shared_ptr<Optimizer> PipelineStage::optimizer() { return optimizer_; }

Collaborator: Just delete it.

Contributor (author): Deleted.

"nthread_per_process": 8,
"num_iteration": 10,
"batch_size": 40,
"batch_size": 20,

Collaborator: Why change it to 20?

Contributor (author): Changed.

kilinchange (Collaborator) commented: Put the ddp-related files in their own ddp folder under parallel.
