diff --git a/docs/precision_checker_guide.md b/docs/precision_checker_guide.md index 0ea4ec8f..2fe9f1c2 100644 --- a/docs/precision_checker_guide.md +++ b/docs/precision_checker_guide.md @@ -1,17 +1,18 @@ # Precision Checker 使用指南 -精度检查工具,用于检测模型训练过程中的数值稳定性问题(NaN、Inf 等),支持 MD5 哈希对比和多种输出格式。 +精度检查工具,用于检测模型训练过程中的数值稳定性问题(NaN、Inf 等),支持 tensor 统计信息输出、MD5 哈希对比和 NPY 文件保存。 ## 功能特性 - **自动检测 NaN/Inf**:在前向和反向传播过程中自动检测异常值 - **多级别检查**:支持 Module 级别和 Function 级别的精度检查 - **灵活配置**:通过 key=value 字符串配置所有选项 -- **MD5 哈希**:支持输出 tensor 的 MD5 值用于对比 -- **表格格式**:支持表格化输出,便于查看和对比 -- **基准对比**:支持加载基准文件进行自动对比 +- **统计信息**:输出 tensor 的 min、max、mean 等统计值 +- **MD5 哈希**:支持输出 tensor 的 MD5 值用于快速对比 +- **NPY 保存**:支持保存 tensor 为 .npy 文件,便于离线分析 - **上下文追踪**:支持 GAS(梯度累积步)和层号追踪 -- **性能优化**:仅在需要时计算 MD5,避免冗余计算 +- **多卡支持**:每个 rank 独立输出到 rank_N 目录 +- **多 iter 覆盖**:同一次运行中,后续 iteration 的文件会覆盖前一个 ## 配置方式 @@ -19,146 +20,157 @@ ```cpp struct PrecisionCheckConfig { - int level = 0; // 0=关闭, 1=MODULE级别, 2=FUNCTION级别 - std::string output_path = ""; // 空=控制台(仅rank0), 非空=文件(所有rank) - bool output_md5 = false; // 输出 MD5 还是 tensor 值 - std::string format = "simple"; // "simple" 或 "table" - std::string baseline_path = ""; // 基准文件路径(用于对比),指定后默认开启 format=table + PrecisionCheckLevel level = PrecisionCheckLevel::OFF; // 0=关闭, 1=MODULE, 2=FUNCTION + std::string output_path = "./log_precision_check"; // 输出目录 + std::string format = "simple"; // "simple" 或 "md5" + bool save_tensors = false; // 是否保存 .npy 文件 + double md5_tolerance = 0.0; // MD5 量化容差(0=不量化) }; ``` -### 配置字符串格式 - -使用 `key=value,key=value` 格式: - -```cpp -auto config = utils::PrecisionCheckConfig::Parse("level=2,format=table,output_md5=true"); -nn::parallel::global::InitAllEnv(nthread, tp_size, sp_enabled, pp_size, vpp_size, config); -``` - ### 配置选项说明 | 选项 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `level` | int | 0 | 0=关闭, 1=MODULE级别, 2=FUNCTION级别 | -| `output_path` | string | "" | 空=控制台(仅rank0), 非空=文件路径(所有rank) | -| `output_md5` | bool | false | true=输出MD5哈希, false=输出tensor值 | -| `format` | string | "simple" | "simple"=简单格式, "table"=表格格式 | -| `baseline` | string | "" | 基准文件路径,用于对比 | - -## 使用方法 +| `path` | string | `./log_precision_check` | 输出目录(自动创建时间戳子目录) | +| `format` | string | `simple` | `simple`=统计信息+前6个值, `md5`=MD5哈希 | +| `save_tensors` | bool | false | 是否保存 tensor 为 .npy 文件 | +| `md5_tolerance` | double | 0.0 | MD5 量化容差(如 1e-3),0=不量化 | -### 1. 基本用法(简单格式) +### 配置字符串格式 -```cpp -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/utils/precision_check_config.h" +使用 `key=value,key=value` 格式: -// 启用 Function 级别检查,输出 tensor 值 -auto config = utils::PrecisionCheckConfig::Parse("level=2"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); +```bash +--precision_check "level=1,path=./my_output,format=simple,save_tensors=true" +``` -// 创建并运行模型 -auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); -x->Fill(2.0f); -x->RequiresGrad(); +## 输出格式 -auto y = x->Mul(x); -auto loss = y->Sum(0, false); -loss->Backward(); -``` +### 目录结构 -输出示例: ``` -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: [2, 2, 2, 2, 2, 2] -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Output MulFunction tensor[0]: [4, 4, 4, 4, 4, 4] +log_precision_check/ +└── 20260122_143052/ # 时间戳子目录 (YYYYMMDD_HHMMSS) + ├── precision_check_rank_0.log # 文本日志 + ├── rank_0/ # NPY 文件目录 (save_tensors=true) + │ ├── Block_0_forward.npy + │ ├── Block_1_forward.npy + │ ├── Block_0_backward.npy + │ └── ... + └── rank_1/ # 多卡时每个 rank 独立目录 + ... ``` -### 2. 
MD5 哈希输出 - -```cpp -// 输出 MD5 而不是 tensor 值 -auto config = utils::PrecisionCheckConfig::Parse("level=2,output_md5=true"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); -``` +### Simple 格式 (format=simple) -输出示例: ``` -I0113 06:44:37.751 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: md5=522b4223c3a2f0dd964caa87cb6eab65 -I0113 06:44:37.751 [Rank 0][PrecisionCheck] Forward Output MulFunction tensor[0]: md5=91d1e78bf226d8735a3bc0ca6968339c +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) min=-2.34 max=3.56 mean=0.12 [1.23, 4.56, 7.89, ...] [NaN:0 Inf:0] +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) min=-2.34 max=3.56 mean=0.12 [1.23, NaN, ...] [NaN:5 Inf:0] <- ERROR ``` -### 3. 表格格式输出 - -```cpp -// 使用表格格式,便于查看和对比 -auto config = utils::PrecisionCheckConfig::Parse("level=2,format=table,output_md5=true"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config); -``` +### MD5 格式 (format=md5) -输出示例: ``` -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| key | level | shape | dtype | same_hash| diff_order| -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| [GAS-0] [L-0] Forward Input MulFunction | 2 | (2, 3) | float32 | True | 0 | -| [GAS-0] [L-0] Forward Output MulFunction | 2 | (2, 3) | float32 | True | 0 | +[GAS-0] [L-0] Block_0_Forward Output tensor[0]: dtype=float32 shape=(2,1024,768) md5=a1b2c3d4e5f6... ``` -### 4. 基准对比 +### NPY 文件命名规则 -```cpp -// 第一次运行:生成基准文件 -auto config1 = utils::PrecisionCheckConfig::Parse("level=2,output_md5=true,output_path=./baseline"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config1); -// ... 运行模型 ... -// 生成文件: ./baseline/precision_check_rank_0.log - -// 第二次运行:与基准对比 -auto config2 = utils::PrecisionCheckConfig::Parse("level=2,format=table,baseline=./baseline/precision_check_rank_0.log"); -nn::parallel::global::InitAllEnv(1, 1, false, 1, 1, config2); -// ... 运行模型 ... -// 输出会显示 same_hash 列,标识是否与基准一致 -``` +文件名格式:`{ModuleName}_{idx}_{stage}.npy` -### 5. 文件输出(所有 rank) +- `ModuleName`: 模块名称(如 Block、LayerNorm) +- `idx`: 同名模块在当前 iteration 内的执行顺序索引 +- `stage`: `forward` 或 `backward` -```cpp -// 输出到文件,所有 rank 都会输出 -auto config = utils::PrecisionCheckConfig::Parse("level=2,output_path=./logs"); -nn::parallel::global::InitAllEnv(8, 2, false, 2, 1, config); -// 生成文件: ./logs/precision_check_rank_0.log, ./logs/precision_check_rank_1.log, ... 
-``` +**多 iteration 行为**:每个 iteration 开始时索引重置为 0,文件会被覆盖。最终只保留最后一个 iteration 的数据。 ## 命令行使用 ### GPT2 示例 ```bash -# 基本检查(简单格式,输出 tensor 值) -./gpt2 --precision_check "level=2" +# 基本检查(Simple 格式,输出到文件) +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1" \ + --num_iteration 1 + +# 保存 NPY 文件 +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,save_tensors=true" \ + --num_iteration 1 + +# MD5 格式(用于快速对比) +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,format=md5" \ + --num_iteration 1 + +# 自定义输出路径 +./build/gpt2 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,path=./my_precision_check,save_tensors=true" \ + --num_iteration 1 +``` + +### LLaMA3 示例 -# 输出 MD5 哈希 -./gpt2 --precision_check "level=2,output_md5=true" +```bash +./build/llama3 --device cuda \ + --input_bin /path/to/data.bin \ + --llmc_filepath /path/to/model.bin \ + --precision_check "level=1,save_tensors=true" \ + --num_iteration 1 +``` -# 表格格式 -./gpt2 --precision_check "level=2,format=table,output_md5=true" +## 离线对比工具 -# 生成基准文件 -./gpt2 --precision_check "level=2,output_md5=true,output_path=./baseline" +### precision_compare.py -# 与基准对比 -./gpt2 --precision_check "level=2,format=table,baseline=./baseline/precision_check_rank_0.log" +用于对比两次运行的 NPY 文件: + +```bash +python scripts/precision_check/precision_compare.py \ + --dir1 ./precision_check/20260122_143052 \ + --dir2 ./precision_check/20260122_143105 \ + --atol 1e-5 \ + --rtol 1e-3 ``` -### LLaMA3 示例 +输出示例: +``` +Comparing Block_0_forward.npy: + Shape: (2, 1024, 768) vs (2, 1024, 768) ✓ + Dtype: float32 vs float32 ✓ + Max abs diff: 1.23e-06 ✓ + Max rel diff: 2.34e-07 ✓ + +Summary: 433/433 files passed +``` + +## 测试验证 + +使用 `test_precision_check` 二进制进行功能验证: ```bash -# 基本检查 -./llama3 --precision_check "level=2" +# 运行全部测试(Simple/MD5格式、NPY保存、多iter覆盖) +./build/test_precision_check + +# 运行特定级别测试 +./build/test_precision_check "level=1" # Module 级别 +./build/test_precision_check "level=2" # Function 级别 -# 表格格式 + MD5 -./llama3 --precision_check "level=2,format=table,output_md5=true" +# 测试不同配置选项 +./build/test_precision_check "level=1,format=simple" # Simple 格式 +./build/test_precision_check "level=1,format=md5" # MD5 格式 +./build/test_precision_check "level=1,save_tensors=true" # 保存 NPY 文件 ``` ## 上下文追踪 @@ -168,132 +180,79 @@ nn::parallel::global::InitAllEnv(8, 2, false, 2, 1, config); ```cpp #include "infini_train/include/utils/precision_check_context.h" -// 在训练循环中设置上下文 for (int gas_step = 0; gas_step < grad_accum_steps; ++gas_step) { PrecisionCheckContext::Instance().SetGAS(gas_step); for (int layer = 0; layer < num_layers; ++layer) { PrecisionCheckContext::Instance().SetLayer(layer); - PrecisionCheckContext::Instance().SetLayerName("transformer_block"); - // 运行该层的前向传播 // 输出会包含 [GAS-X] [L-Y] 前缀 } } ``` -输出示例: -``` -[GAS-0] [L-0] Forward Input MulFunction -[GAS-0] [L-1] Forward Input MulFunction -[GAS-1] [L-0] Forward Input MulFunction -``` - -## 性能优化 - -### MD5 计算优化 - -MD5 仅在以下情况计算: -- `output_md5=true` 时 -- `baseline_path` 非空时(需要对比) - -默认情况下(`output_md5=false` 且无基准对比),不会计算 MD5,避免性能开销。 - -### 使用建议 - -| 场景 | 推荐配置 | -|------|----------| -| 快速调试 | `level=2` | -| 详细调试 | `level=2,format=table` | -| 生成基准 | `level=2,output_md5=true,output_path=./baseline` | -| 对比测试 | 
`level=2,format=table,baseline=./baseline/...` | -| 生产环境 | `level=0`(关闭) | - -## 输出格式对比 - -### Simple 格式 - -``` -I0113 06:44:10.575 [Rank 0][PrecisionCheck] Forward Input MulFunction tensor[0]: [2, 2, 2, 2, 2, 2] -``` - -优点:紧凑,易于阅读 -缺点:不便于对比多个 tensor - -### Table 格式 - -``` -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| key | level | shape | dtype | same_hash| diff_order| -+--------------------------------------------------+-------+------------------+---------------+----------+----------+ -| [GAS-0] [L-0] Forward Input MulFunction | 2 | (2, 3) | float32 | True | 0 | -``` - -优点:结构化,便于对比和分析 -缺点:占用更多空间 - ## 手动注册(高级用法) -除了通过 `InitAllEnv` 自动注册,也可以手动为特定模块注册: +除了通过命令行自动注册,也可以手动为特定模块注册: ```cpp #include "infini_train/include/utils/precision_checker.h" -// 配置精度检查器 utils::PrecisionChecker::Config config; config.check_nan = true; config.check_inf = true; -config.print_stats = true; config.abort_on_error = false; // 为特定模块注册 utils::PrecisionChecker::RegisterForModule(model.get(), "MyModel", config); - -// 为特定 Function 注册 -utils::PrecisionChecker::RegisterForFunction(my_function.get(), "MyFunction", config); ``` ## 实现原理 -精度检查器通过 Hook 机制实现: +### Hook 机制 -1. **Forward Pre-Hook**:检查输入 tensor -2. **Forward Post-Hook**:检查输出 tensor -3. **Backward Hooks**:自动检查梯度 +精度检查器通过 Hook 机制实现: -检查流程: ``` Forward Pass: - ├─> Pre-Hook: 检查输入 - ├─> Forward: 执行计算 - └─> Post-Hook: 检查输出 + └─> Post-Hook: 检查输出 tensor Backward Pass: - ├─> Backward Pre-Hook: 检查梯度输入 - ├─> Backward: 执行梯度计算 - └─> Backward Post-Hook: 检查梯度输出 + └─> Post-Hook: 检查梯度 tensor ``` -## 示例代码 +### Counter 机制 -参见: -- `test/hook/test_precision_check.cc` - 完整使用示例 -- `infini_train/include/utils/precision_checker.h` - API 文档 -- `infini_train/include/utils/precision_check_config.h` - 配置结构 -- `infini_train/include/utils/precision_check_context.h` - 上下文追踪 +为了支持多 iteration 文件覆盖,使用 `thread_local` 计数器: + +```cpp +// 每个 iteration 开始时重置 +PrecisionChecker::ResetCounters(); -## 测试 +// 每次 CheckTensors 时递增 +int idx = PrecisionCheckEnv::GetAndIncrementCounter(counter_key); +// 文件名: Block_{idx}_forward.npy +``` -```bash -# 运行测试(默认:简单格式) -./test_precision_check +这确保了: +- 同一 iteration 内,同名模块有不同的索引(Block_0, Block_1, ...) 
+- 不同 iteration 之间,索引重置,文件被覆盖 -# Function 级别 + MD5 -./test_precision_check "level=2,output_md5=true" +## 使用建议 + +| 场景 | 推荐配置 | +|------|----------| +| 快速调试 | `level=1` | +| 详细分析 | `level=1,save_tensors=true` | +| 快速对比 | `level=1,format=md5` | +| MD5 容差对比 | `level=1,format=md5,md5_tolerance=1e-3` | +| 生产环境 | `level=0`(关闭) | -# 表格格式 -./test_precision_check "level=2,format=table,output_md5=true" +## 相关文件 -# Module 级别 -./test_precision_check "level=1" -``` +- `infini_train/include/utils/precision_checker.h` - API 定义 +- `infini_train/include/utils/precision_check_config.h` - 配置结构 +- `infini_train/include/utils/precision_check_context.h` - 上下文追踪 +- `infini_train/include/utils/global_module_hook_registry.h` - 全局模块 Hook 注册 +- `scripts/precision_check/precision_compare.py` - 离线对比工具 +- `test/hook/test_precision_check.cc` - 精度检查测试 \ No newline at end of file diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index a1a58ed5..ac123901 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -26,7 +26,9 @@ #include "infini_train/include/profiler.h" #endif #include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" +#include "infini_train/include/utils/precision_checker.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" @@ -257,6 +259,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + // Reset precision check counters at start of each iteration for file overwrite + utils::PrecisionChecker::ResetCounters(); + const bool last_step = step == FLAGS_num_iteration; const auto iter_start = std::chrono::high_resolution_clock::now(); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 3a4e5053..1c885807 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -25,7 +25,9 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" +#include "infini_train/include/utils/precision_checker.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" @@ -232,6 +234,9 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + // Reset precision check counters at start of each iteration for file overwrite + utils::PrecisionChecker::ResetCounters(); + const bool last_step = step == FLAGS_num_iteration; const auto iter_start = std::chrono::high_resolution_clock::now(); diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 43d77de6..b61236bb 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -95,7 +95,6 @@ class Module : public std::enable_shared_from_this { std::vector forward_post_hooks_; std::vector backward_pre_hooks_; std::vector backward_post_hooks_; - bool precision_check_registered_ = false; private: std::unordered_map> diff --git a/infini_train/include/utils/global_module_hook_registry.h b/infini_train/include/utils/global_module_hook_registry.h new file mode 100644 index 00000000..717a8f77 --- /dev/null +++ 
b/infini_train/include/utils/global_module_hook_registry.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <functional>
+#include <mutex>
+#include <unordered_set>
+#include <vector>
+
+namespace infini_train {
+namespace nn {
+class Module;
+}
+
+namespace utils {
+
+// Global Module Hook Registry
+// Manages hooks that need to be applied to all modules
+class GlobalModuleHookRegistry {
+public:
+    using ModuleHookRegistrar = std::function<void(nn::Module *)>;
+
+    static GlobalModuleHookRegistry &Instance();
+
+    // Register a hook registrar, which will be called for all modules on their first forward pass
+    void RegisterHook(ModuleHookRegistrar registrar);
+
+    // Apply all registered hooks to the specified module (called by Module::operator())
+    void ApplyHooks(nn::Module *module);
+
+private:
+    GlobalModuleHookRegistry() = default;
+
+    std::vector<ModuleHookRegistrar> registrars_;
+    std::unordered_set<nn::Module *> applied_modules_;
+    mutable std::mutex mutex_;
+};
+
+} // namespace utils
+} // namespace infini_train
diff --git a/infini_train/include/utils/precision_check_config.h b/infini_train/include/utils/precision_check_config.h
index 25524fb7..f272775e 100644
--- a/infini_train/include/utils/precision_check_config.h
+++ b/infini_train/include/utils/precision_check_config.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include
+#include
 
 namespace infini_train {
 namespace utils {
@@ -9,10 +10,11 @@ enum class PrecisionCheckLevel { OFF = 0, MODULE = 1, FUNCTION = 2 };
 
 struct PrecisionCheckConfig {
     PrecisionCheckLevel level = PrecisionCheckLevel::OFF;
-    std::string output_path = "";   // empty=console(rank0), non-empty=file(all ranks)
-    bool output_md5 = false;        // output MD5 hash or tensor values
-    std::string format = "simple";  // "simple" or "table"
-    std::string baseline_path = ""; // baseline file path for comparison
+    std::string output_path = "./log_precision_check"; // Output path (default)
+    std::string format = "simple";                      // "simple" or "md5"
+    bool save_tensors = false;                          // Whether to output .npy file
+    double md5_tolerance = 0.0;                         // MD5 tolerance for quantization (e.g., 1e-3)
+                                                        // 0 means no quantization (original precision)
 
     // Parse from "key=value,key=value" string
     static PrecisionCheckConfig Parse(const std::string &config_str);
@@ -23,10 +25,16 @@ class PrecisionCheckEnv {
     static PrecisionCheckEnv &Instance();
     void Init(const PrecisionCheckConfig &config);
     const PrecisionCheckConfig &GetConfig() const;
+    const std::string &GetOutputPath() const;
+
+    // Tensor counter management for file overwrite across iterations (thread-local)
+    static int GetAndIncrementCounter(const std::string &key);
+    static void ResetCounters();
 
 private:
     PrecisionCheckEnv() = default;
     PrecisionCheckConfig config_;
+    std::string timestamped_path_ = ""; // Actual output path (with timestamp)
 };
 } // namespace utils
 } // namespace infini_train
diff --git a/infini_train/include/utils/precision_checker.h b/infini_train/include/utils/precision_checker.h
index 060ccb98..2d835694 100644
--- a/infini_train/include/utils/precision_checker.h
+++ b/infini_train/include/utils/precision_checker.h
@@ -4,6 +4,8 @@
 #include
 #include
 
+#include "infini_train/include/utils/precision_check_config.h"
+
 namespace infini_train {
 class Tensor;
 class HookHandle;
@@ -32,6 +34,10 @@ class PrecisionChecker {
         return default_config;
     }
 
+    // Initialize global module-level precision checking
+    // Called automatically by PrecisionCheckEnv::Init when level >= MODULE
+    static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());
+
     static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
                                     const Config &config =
DefaultConfig()); @@ -39,6 +45,9 @@ class PrecisionChecker { static void RegisterForModule(nn::Module *module, const std::string &name = "", const Config &config = DefaultConfig()); + // Reset tensor counters (call at start of each iteration for file overwrite) + static void ResetCounters(); + private: static void CheckTensors(const std::string &stage, const std::string &name, const std::vector> &tensors, const Config &config); diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 04815c48..a4a25498 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -12,8 +12,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" -#include "infini_train/include/utils/precision_check_config.h" -#include "infini_train/include/utils/precision_checker.h" +#include "infini_train/include/utils/global_module_hook_registry.h" #ifndef UNLIKELY #define UNLIKELY(x) __builtin_expect(!!(x), 0) @@ -135,15 +134,8 @@ std::vector> Module::Forward(const std::vector> Module::operator()(const std::vector> &input_tensors) { - // Register precision check hooks if enabled and not already registered - // TODO(cx): move RegisterForModule to PrecisionChecker and avoid duplicate registration - if (!precision_check_registered_) { - auto precision_level = utils::PrecisionCheckEnv::Instance().GetConfig().level; - if (precision_level == utils::PrecisionCheckLevel::MODULE) { - utils::PrecisionChecker::RegisterForModule(this); - precision_check_registered_ = true; - } - } + // Apply globally registered hooks (on first call for this module) + utils::GlobalModuleHookRegistry::Instance().ApplyHooks(this); // Call forward pre-hooks for (const auto &hook : forward_pre_hooks_) { diff --git a/infini_train/src/utils/global_module_hook_registry.cc b/infini_train/src/utils/global_module_hook_registry.cc new file mode 100644 index 00000000..54972450 --- /dev/null +++ b/infini_train/src/utils/global_module_hook_registry.cc @@ -0,0 +1,24 @@ +#include "infini_train/include/utils/global_module_hook_registry.h" + +namespace infini_train::utils { + +GlobalModuleHookRegistry &GlobalModuleHookRegistry::Instance() { + static GlobalModuleHookRegistry instance; + return instance; +} + +void GlobalModuleHookRegistry::RegisterHook(ModuleHookRegistrar registrar) { + std::lock_guard lock(mutex_); + registrars_.push_back(std::move(registrar)); +} + +void GlobalModuleHookRegistry::ApplyHooks(nn::Module *module) { + std::lock_guard lock(mutex_); + if (applied_modules_.contains(module)) { + return; + } + for (const auto ®istrar : registrars_) { registrar(module); } + applied_modules_.insert(module); +} + +} // namespace infini_train::utils diff --git a/infini_train/src/utils/precision_check_config.cc b/infini_train/src/utils/precision_check_config.cc index d37cbdb0..7e1b7176 100644 --- a/infini_train/src/utils/precision_check_config.cc +++ b/infini_train/src/utils/precision_check_config.cc @@ -1,10 +1,20 @@ #include "infini_train/include/utils/precision_check_config.h" +#include +#include +#include #include #include +#include "infini_train/include/utils/precision_checker.h" + namespace infini_train::utils { +namespace { +// Thread-local tensor counter for precision check file indexing +thread_local std::unordered_map tls_g_tensor_counter; +} // namespace + PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) { PrecisionCheckConfig config; if (config_str.empty()) { @@ 
-25,20 +35,17 @@ PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) int level_int = std::stoi(kv_map["level"]); config.level = static_cast(level_int); } - if (kv_map.count("output_path")) { - config.output_path = kv_map["output_path"]; - } - if (kv_map.count("output_md5")) { - config.output_md5 = (kv_map["output_md5"] == "true" || kv_map["output_md5"] == "1"); - } - if (kv_map.count("baseline")) { - config.baseline_path = kv_map["baseline"]; + if (kv_map.count("path")) { + config.output_path = kv_map["path"]; } if (kv_map.count("format")) { config.format = kv_map["format"]; - } else if (!config.baseline_path.empty()) { - // Default to table format when baseline is specified - config.format = "table"; + } + if (kv_map.count("save_tensors")) { + config.save_tensors = (kv_map["save_tensors"] == "true" || kv_map["save_tensors"] == "1"); + } + if (kv_map.count("md5_tolerance")) { + config.md5_tolerance = std::stod(kv_map["md5_tolerance"]); } return config; } @@ -48,8 +55,34 @@ PrecisionCheckEnv &PrecisionCheckEnv::Instance() { return instance; } -void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { config_ = config; } +void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { + config_ = config; + if (config_.level != PrecisionCheckLevel::OFF) { + // Create timestamped subdirectory: output_path/YYYYMMDD_HHMMSS/ + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::tm tm; + localtime_r(&time_t, &tm); + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y%m%d_%H%M%S", &tm); + + timestamped_path_ = config_.output_path + "/" + buf; + std::filesystem::create_directories(timestamped_path_); + + // Initialize PrecisionChecker (registers global module hooks) + PrecisionChecker::Init(config_); + + // Output precision check output path + std::cout << "[PrecisionCheck] Output: " << timestamped_path_ << std::endl; + } +} const PrecisionCheckConfig &PrecisionCheckEnv::GetConfig() const { return config_; } +const std::string &PrecisionCheckEnv::GetOutputPath() const { return timestamped_path_; } + +int PrecisionCheckEnv::GetAndIncrementCounter(const std::string &key) { return tls_g_tensor_counter[key]++; } + +void PrecisionCheckEnv::ResetCounters() { tls_g_tensor_counter.clear(); } + } // namespace infini_train::utils diff --git a/infini_train/src/utils/precision_check_context.cc b/infini_train/src/utils/precision_check_context.cc index 8c1b6917..a9f33d65 100644 --- a/infini_train/src/utils/precision_check_context.cc +++ b/infini_train/src/utils/precision_check_context.cc @@ -3,8 +3,8 @@ namespace infini_train::utils { PrecisionCheckContext &PrecisionCheckContext::Instance() { - static thread_local PrecisionCheckContext instance; - return instance; + static thread_local PrecisionCheckContext tls_instance; + return tls_instance; } void PrecisionCheckContext::SetGAS(int gas) { gas_ = gas; } diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index 60301aa9..7966f4cb 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -5,14 +5,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include -#include #include #include "infini_train/include/autograd/function.h" @@ -20,6 +19,7 @@ #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" +#include 
"infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" #include "infini_train/include/utils/precision_check_context.h" @@ -153,123 +153,29 @@ std::string ComputeMD5(const void *data, size_t size) { return md5.Finalize(); } -// Baseline storage -std::unordered_map &GetBaseline() { - static std::unordered_map baseline; - static bool loaded = false; - static std::mutex load_mutex; - - if (!loaded) { - std::lock_guard lock(load_mutex); - if (!loaded) { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - if (!config.baseline_path.empty()) { - std::ifstream file(config.baseline_path); - if (!file.is_open()) { - std::cerr << "[PrecisionCheck] Failed to open baseline file: " << config.baseline_path << std::endl; - } else { - std::string line; - while (std::getline(file, line)) { - // Try format 1: key|md5 - auto pipe_pos = line.rfind('|'); - if (pipe_pos != std::string::npos) { - std::string key = line.substr(0, pipe_pos); - std::string md5 = line.substr(pipe_pos + 1); - baseline[key] = md5; - } else { - // Try format 2: simple log format with "md5=" - auto md5_pos = line.find("md5="); - if (md5_pos != std::string::npos) { - // Extract md5 value - std::string md5 = line.substr(md5_pos + 4); - - // Extract key: find text between "][PrecisionCheck] " and ": md5=" - auto check_pos = line.find("][PrecisionCheck] "); - if (check_pos != std::string::npos) { - size_t key_start = check_pos + 18; // length of "][PrecisionCheck] " - size_t key_end = line.find(": md5=", key_start); - if (key_end != std::string::npos) { - std::string key = line.substr(key_start, key_end - key_start); - baseline[key] = md5; - } - } - } - } - } - std::cout << "[PrecisionCheck] Loaded " << baseline.size() << " baseline entries from " - << config.baseline_path << std::endl; - } - } - loaded = true; - } - } - return baseline; -} - -// Table header printed flag -bool &TableHeaderPrinted() { - thread_local bool printed = false; - return printed; -} - std::ostream &GetLogStream() { - thread_local std::ofstream log_file; - thread_local std::mutex init_mutex; - thread_local bool initialized = false; - thread_local bool use_console = false; - - if (!initialized) { - std::lock_guard lock(init_mutex); - if (!initialized) { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - - if (config.output_path.empty()) { - use_console = true; + thread_local std::ofstream tls_log_file; + thread_local std::mutex tls_init_mutex; + thread_local bool tls_initialized = false; + + if (!tls_initialized) { + std::lock_guard lock(tls_init_mutex); + if (!tls_initialized) { + const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); + int global_rank = nn::parallel::global::thread_global_rank; + std::string filename = output_path + "/precision_check_rank_" + std::to_string(global_rank) + ".log"; + tls_log_file.open(filename, std::ios::out | std::ios::trunc); + if (!tls_log_file.is_open()) { + std::cerr << "[Rank " << global_rank << "] Failed to open precision check log file: " << filename + << std::endl; } else { - // Create output directory if it doesn't exist - mkdir(config.output_path.c_str(), 0755); - - int global_rank = nn::parallel::global::thread_global_rank; - std::string filename - = config.output_path + "/precision_check_rank_" + std::to_string(global_rank) + ".log"; - log_file.open(filename, std::ios::out | std::ios::trunc); - if (!log_file.is_open()) { - std::cerr << "[Rank " << global_rank << "] Failed to open precision check log file: " 
<< filename - << std::endl; - use_console = true; - } else { - use_console = false; - std::cout << "[Rank " << global_rank << "] Precision check output: " << filename << std::endl; - } + std::cout << "[Rank " << global_rank << "] Precision check output: " << filename << std::endl; } - initialized = true; + tls_initialized = true; } } - return use_console ? std::cout : log_file; -} - -bool ShouldPrint() { - const auto &config = PrecisionCheckEnv::Instance().GetConfig(); - if (!config.output_path.empty()) { - return true; - } - return nn::parallel::global::GlobalEnv::Instance().global_proc_rank() == 0; -} - -std::string GetTimestamp() { - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; - - std::tm tm; - localtime_r(&time_t, &tm); - - std::ostringstream oss; - oss << std::setfill('0') << std::setw(2) << (tm.tm_mon + 1) << std::setw(2) << tm.tm_mday << ' ' << std::setw(2) - << tm.tm_hour << ':' << std::setw(2) << tm.tm_min << ':' << std::setw(2) << tm.tm_sec << '.' << std::setw(3) - << ms.count(); - return oss.str(); + return tls_log_file.is_open() ? tls_log_file : std::cout; } std::string FormatShape(const std::vector &shape) { @@ -277,7 +183,7 @@ std::string FormatShape(const std::vector &shape) { oss << "("; for (size_t i = 0; i < shape.size(); ++i) { if (i > 0) { - oss << ", "; + oss << ","; } oss << shape[i]; } @@ -302,63 +208,81 @@ std::string DataTypeToString(DataType dtype) { } } -void PrintTableHeader(std::ostream &os) { - if (TableHeaderPrinted()) { - return; +struct TensorStats { + float min_val = 0; + float max_val = 0; + float mean_val = 0; + int nan_count = 0; + int inf_count = 0; +}; + +TensorStats ComputeStats(const float *data, size_t num_elements) { + TensorStats stats; + if (num_elements == 0) { + return stats; } - TableHeaderPrinted() = true; - - os << "+" << std::string(50, '-') << "+" << std::string(7, '-') << "+" << std::string(18, '-') << "+" - << std::string(15, '-') << "+" << std::string(10, '-') << "+\n"; - os << "| " << std::left << std::setw(49) << "key" - << "| " << std::setw(6) << "level" - << "| " << std::setw(17) << "shape" - << "| " << std::setw(14) << "dtype" - << "| " << std::setw(9) << "same_hash" - << "|\n"; - os << "+" << std::string(50, '-') << "+" << std::string(7, '-') << "+" << std::string(18, '-') << "+" - << std::string(15, '-') << "+" << std::string(10, '-') << "+\n"; -} -void PrintTableRow(std::ostream &os, const std::string &key, int level, const std::string &shape, - const std::string &dtype, const std::string &same_hash) { - os << "| " << std::left << std::setw(49) << key.substr(0, 49) << "| " << std::setw(6) << level << "| " - << std::setw(17) << shape.substr(0, 17) << "| " << std::setw(14) << dtype << "| " << std::setw(9) << same_hash - << "|\n"; -} + stats.min_val = std::numeric_limits::max(); + stats.max_val = std::numeric_limits::lowest(); + double sum = 0; -// Calculate diff order between two tensors (returns string like "1e-3" or "0") -std::string CalculateDiffOrder(const float *data1, const float *data2, size_t size) { - if (!data1 || !data2 || size == 0) { - return "N/A"; + for (size_t i = 0; i < num_elements; ++i) { + float val = data[i]; + if (std::isnan(val)) { + stats.nan_count++; + continue; + } + if (std::isinf(val)) { + stats.inf_count++; + continue; + } + stats.min_val = std::min(stats.min_val, val); + stats.max_val = std::max(stats.max_val, val); + sum += val; } - double max_diff = 0.0; - for (size_t i 
= 0; i < size; ++i) { - double diff = std::abs(static_cast(data1[i]) - static_cast(data2[i])); - max_diff = std::max(max_diff, diff); - } + size_t valid_count = num_elements - stats.nan_count - stats.inf_count; + stats.mean_val = valid_count > 0 ? static_cast(sum / valid_count) : 0; + + return stats; +} - if (max_diff == 0.0) { - return "0"; +// Quantize float data to specified tolerance for MD5 calculation +// e.g., tolerance=1e-3: 4.0003 and 4.0004 both become 4.000 +std::vector QuantizeData(const float *data, size_t num_elements, double tolerance) { + std::vector quantized(num_elements); + double inv_tolerance = 1.0 / tolerance; + for (size_t i = 0; i < num_elements; ++i) { + quantized[i] = static_cast(std::round(data[i] * inv_tolerance) * tolerance); } + return quantized; +} - int order = static_cast(std::floor(std::log10(max_diff))); - return "1e" + std::to_string(order); +void SaveNpy(const std::shared_ptr &tensor, const std::string &name, int idx, const std::string &stage, + int rank) { + const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); + std::string dir = output_path + "/rank_" + std::to_string(rank); + std::filesystem::create_directories(dir); + std::string filename = dir + "/" + name + "_" + std::to_string(idx) + "_" + stage + ".npy"; + + if (tensor->Dtype() == DataType::kFLOAT32) { + tensor->SaveAsNpy(filename); + } else { + auto float_tensor = tensor->To(DataType::kFLOAT32); + float_tensor.SaveAsNpy(filename); + } } } // namespace void PrecisionChecker::CheckTensors(const std::string &stage, const std::string &name, const std::vector> &tensors, const Config &config) { - if (!ShouldPrint()) { + const auto &global_config = PrecisionCheckEnv::Instance().GetConfig(); + if (global_config.level == PrecisionCheckLevel::OFF) { return; } - const auto &global_config = PrecisionCheckEnv::Instance().GetConfig(); const int rank = nn::parallel::global::thread_global_rank; - const auto level = global_config.level; - auto &baseline = GetBaseline(); for (size_t i = 0; i < tensors.size(); ++i) { if (!tensors[i]) { @@ -376,110 +300,76 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string cpu_tensor = tensor; } - const void *data = cpu_tensor->DataPtr(); + const float *float_data = static_cast(cpu_tensor->DataPtr()); const size_t byte_size = cpu_tensor->SizeInBytes(); const size_t num_elements = cpu_tensor->NumElements(); - // Build key + // Build context key const std::string context_key = PrecisionCheckContext::Instance().GetKey(); - const std::string full_key = context_key.empty() ? (stage + " " + name + " tensor[" + std::to_string(i) + "]") - : (context_key + " " + stage + " " + name); - - // Only compute MD5 if needed (for output or baseline comparison) - const bool need_md5 = global_config.output_md5 || !baseline.empty(); - std::string md5; - if (need_md5) { - md5 = ComputeMD5(data, byte_size); - } + const std::string stage_short = (stage.find("Forward") != std::string::npos) ? 
"forward" : "backward"; - // Check baseline - const bool has_baseline = !baseline.empty(); - bool same_hash = true; - if (has_baseline) { - auto it = baseline.find(full_key); - if (it == baseline.end() && !context_key.empty()) { - // Try without context: "stage name tensor[i]" - std::string key_without_context = stage + " " + name + " tensor[" + std::to_string(i) + "]"; - it = baseline.find(key_without_context); - } - if (it != baseline.end()) { - same_hash = (it->second == md5); - } + // Get tensor index for this (name, stage) combination + std::string counter_key = name + "_" + stage_short; + int idx = PrecisionCheckEnv::GetAndIncrementCounter(counter_key); + + // Save NPY if enabled + if (global_config.save_tensors) { + SaveNpy(cpu_tensor, name, idx, stage_short, rank); } + // Output to log auto &log_stream = GetLogStream(); - if (global_config.format == "table") { - thread_local bool header_printed = false; - if (!header_printed) { - PrintTableHeader(log_stream); - header_printed = true; - } - std::string same_hash_str = has_baseline ? (same_hash ? "True" : "False") : "--"; - PrintTableRow(log_stream, full_key, static_cast(level), FormatShape(cpu_tensor->Dims()), - DataTypeToString(cpu_tensor->Dtype()), same_hash_str); - - // Save to baseline file if output_path is set and output_md5 is true - if (!global_config.output_path.empty() && global_config.output_md5) { - log_stream << full_key << "|" << md5 << std::endl; + if (global_config.format == "md5") { + // MD5 format + std::string md5; + if (global_config.md5_tolerance > 0.0) { + // Quantize data before computing MD5 + // Convert to float32 if needed for quantization + std::shared_ptr float32_tensor = cpu_tensor; + if (cpu_tensor->Dtype() != DataType::kFLOAT32) { + float32_tensor = std::make_shared(cpu_tensor->To(DataType::kFLOAT32)); + } + const float *data_ptr = static_cast(float32_tensor->DataPtr()); + size_t num_elems = float32_tensor->NumElements(); + auto quantized = QuantizeData(data_ptr, num_elems, global_config.md5_tolerance); + md5 = ComputeMD5(quantized.data(), quantized.size() * sizeof(float)); + } else { + // Original precision MD5 + md5 = ComputeMD5(cpu_tensor->DataPtr(), byte_size); } + log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " + << "shape=" << FormatShape(cpu_tensor->Dims()) << " " + << "md5=" << md5 << std::endl; } else { - // Simple format - const float *float_data = static_cast(data); - - bool has_nan = false; - bool has_inf = false; - for (size_t j = 0; j < num_elements; ++j) { - float val = float_data[j]; - if (std::isnan(val)) { - has_nan = true; - } - if (std::isinf(val)) { - has_inf = true; + // Simple format (default) + TensorStats stats = ComputeStats(float_data, num_elements); + + const bool has_error + = (config.check_nan && stats.nan_count > 0) || (config.check_inf && stats.inf_count > 0); + const std::string error_marker = has_error ? 
" <- ERROR" : ""; + + log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: " + << "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " " + << "shape=" << FormatShape(cpu_tensor->Dims()) << " " + << "min=" << stats.min_val << " " + << "max=" << stats.max_val << " " + << "mean=" << stats.mean_val << " ["; + + // Print first 6 values + constexpr size_t max_print = 6; + for (size_t j = 0; j < std::min(num_elements, max_print); ++j) { + if (j > 0) { + log_stream << ", "; } + log_stream << float_data[j]; } - - const bool has_error = (config.check_nan && has_nan) || (config.check_inf && has_inf); - - // When output_path is set, always write to file; otherwise only write on error or if print_stats is enabled - const bool should_output = !global_config.output_path.empty() || has_error || config.print_stats; - - if (should_output) { - const std::string log_level = has_error ? "E" : "I"; - - log_stream << log_level << GetTimestamp() << " [Rank " << rank << "][PrecisionCheck] " << stage << " " - << name << " tensor[" << i << "]: "; - - if (global_config.output_md5) { - log_stream << "md5=" << md5; - if (!same_hash) { - log_stream << " (MISMATCH)"; - } - } else { - log_stream << "["; - if (has_nan) { - log_stream << " NaN detected!"; - } - if (has_inf) { - log_stream << " Inf detected!"; - } - - if (config.print_stats) { - constexpr size_t max_print = 6; - for (size_t j = 0; j < std::min(num_elements, max_print); ++j) { - if (j > 0) { - log_stream << ", "; - } - log_stream << float_data[j]; - } - if (num_elements > max_print) { - log_stream << ", ..."; - } - } - log_stream << "]"; - } - log_stream << std::endl; + if (num_elements > max_print) { + log_stream << ", ..."; } + log_stream << "] [NaN:" << stats.nan_count << " Inf:" << stats.inf_count << "]" << error_marker + << std::endl; if (has_error && config.abort_on_error) { std::cerr << "Precision check failed, aborting!" << std::endl; @@ -489,6 +379,16 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string } } +void PrecisionChecker::Init(const PrecisionCheckConfig &global_config, const Config &config) { + if (global_config.level == PrecisionCheckLevel::OFF) { + return; + } + + // Register auto-hook for all modules via GlobalModuleHookRegistry + GlobalModuleHookRegistry::Instance().RegisterHook( + [config](nn::Module *module) { RegisterForModule(module, module->type(), config); }); +} + void PrecisionChecker::RegisterForFunction(autograd::Function *func, const std::string &name, const Config &config) { const std::string func_name = name.empty() ? 
"Function" : name; @@ -520,14 +420,16 @@ void PrecisionChecker::RegisterForModule(nn::Module *module, const std::string & module->RegisterForwardPostHook([module_name, config](nn::Module *, const std::vector> &, const std::vector> &outputs) { - CheckTensors("Module Forward Output", module_name, outputs, config); + CheckTensors("Forward Output", module_name, outputs, config); }); module->RegisterBackwardPostHook([module_name, config](nn::Module *, const std::vector> &grad_inputs, const std::vector> &) { - CheckTensors("Module Backward Output", module_name, grad_inputs, config); + CheckTensors("Backward Output", module_name, grad_inputs, config); }); } +void PrecisionChecker::ResetCounters() { PrecisionCheckEnv::ResetCounters(); } + } // namespace infini_train::utils diff --git a/tools/compare_loss.py b/scripts/compare_loss.py similarity index 86% rename from tools/compare_loss.py rename to scripts/compare_loss.py index 7a38c0d8..d8ed8718 100755 --- a/tools/compare_loss.py +++ b/scripts/compare_loss.py @@ -54,6 +54,7 @@ def main(): parser.add_argument('--threshold', type=float, help='Loss difference threshold (deprecated, use --threshold-fp32 and --threshold-bf16)') parser.add_argument('--threshold-fp32', type=float, default=1e-5, help='Loss difference threshold for fp32 (default: 1e-5)') parser.add_argument('--threshold-bf16', type=float, default=1e-2, help='Loss difference threshold for bfloat16 (default: 1e-2)') + parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones') args = parser.parse_args() # Support legacy --threshold argument @@ -69,9 +70,9 @@ def main(): common = set(files1.keys()) & set(files2.keys()) if only_in_1: - print(f"Files only in {args.dir1}: {', '.join(sorted(only_in_1))}") + print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}") if only_in_2: - print(f"Files only in {args.dir2}: {', '.join(sorted(only_in_2))}") + print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}") if only_in_1 or only_in_2: print() @@ -90,10 +91,10 @@ def main(): else: fp32_total += 1 - print(f"Comparing {name} ({dtype}, threshold: {threshold:.0e}):") total_steps, num_mismatches, mismatches = compare_files(files1[name], files2[name], threshold) if mismatches: + print(f"Comparing {name} ({dtype}, threshold: {threshold:.0e}):") for msg in mismatches: print(msg) total_mismatches += num_mismatches @@ -103,9 +104,12 @@ def main(): else: fp32_passed += 1 - matched = total_steps - num_mismatches - print(f" Summary: {matched}/{total_steps} steps matched") - print() + # Only print details when there are mismatches or verbose mode + if mismatches or args.verbose: + if mismatches: + matched = total_steps - num_mismatches + print(f" Summary: {matched}/{total_steps} steps matched") + print() print("=" * 50) print(f"Overall Summary:") diff --git a/tools/compare_tps.py b/scripts/compare_tps.py similarity index 72% rename from tools/compare_tps.py rename to scripts/compare_tps.py index b99f97d5..8a3cb804 100755 --- a/tools/compare_tps.py +++ b/scripts/compare_tps.py @@ -31,7 +31,7 @@ def compare_files(file1, file2, threshold): tps2 = {k: v for k, v in tps2.items() if k > 1} if not tps1 or not tps2: - return 0, 1, [" No valid steps found (after excluding step 1)"] + return 0, 1, [" No valid steps found (after excluding step 1)"], 0, 0, 0 # Calculate averages avg1 = sum(tps1.values()) / len(tps1) @@ -44,17 +44,15 @@ def compare_files(file1, file2, threshold): if rel_error > threshold: mismatches.append(f" 
Average tok/s: {avg1:.2f} vs {avg2:.2f} ✗ (error: {rel_error*100:.2f}%, threshold: {threshold*100:.2f}%)") mismatches.append(f" Steps compared: {len(tps1)} vs {len(tps2)} (excluding step 1)") - else: - print(f" Average tok/s: {avg1:.2f} vs {avg2:.2f} ✓ (error: {rel_error*100:.2f}%, threshold: {threshold*100:.2f}%)") - print(f" Steps compared: {len(tps1)} vs {len(tps2)} (excluding step 1)") - return 1, len(mismatches), mismatches + return 1, len(mismatches), mismatches, avg1, avg2, rel_error def main(): parser = ArgumentParser(description='Compare tok/s between two log directories') parser.add_argument('dir1', type=Path, help='First log directory') parser.add_argument('dir2', type=Path, help='Second log directory') parser.add_argument('--threshold', type=float, default=0.20, help='Relative error threshold (default: 0.20 = 20%%)') + parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones') args = parser.parse_args() files1 = {f.name: f for f in args.dir1.glob('*.log')} @@ -65,9 +63,9 @@ def main(): common = set(files1.keys()) & set(files2.keys()) if only_in_1: - print(f"Files only in {args.dir1}: {', '.join(sorted(only_in_1))}") + print(f"Files only in {args.dir1.resolve()}: {', '.join(sorted(only_in_1))}") if only_in_2: - print(f"Files only in {args.dir2}: {', '.join(sorted(only_in_2))}") + print(f"Files only in {args.dir2.resolve()}: {', '.join(sorted(only_in_2))}") if only_in_1 or only_in_2: print() @@ -77,17 +75,24 @@ def main(): for name in sorted(common): total_files += 1 - print(f"Comparing {name}:") - total_comparisons, num_mismatches, mismatches = compare_files(files1[name], files2[name], args.threshold) + total_comparisons, num_mismatches, mismatches, avg1, avg2, rel_error = compare_files(files1[name], files2[name], args.threshold) if mismatches: + print(f"Comparing {name}:") for msg in mismatches: print(msg) total_mismatches += num_mismatches else: passed_files += 1 - - print() + # Only print details when verbose mode is enabled + if args.verbose: + print(f"Comparing {name}:") + print(f" Average tok/s: {avg1:.2f} vs {avg2:.2f} ✓ (error: {rel_error*100:.2f}%, threshold: {args.threshold*100:.2f}%)") + print(f" Steps compared: {len([k for k in parse_log(files1[name]) if k > 1])} (excluding step 1)") + + # Print separator when there are mismatches or verbose mode + if mismatches or args.verbose: + print() print("=" * 50) print(f"Overall Summary:") diff --git a/scripts/precision_check/precision_compare.py b/scripts/precision_check/precision_compare.py new file mode 100755 index 00000000..40c91308 --- /dev/null +++ b/scripts/precision_check/precision_compare.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Precision comparison tool for InfiniTrain tensor outputs. + +Usage: + python precision_compare.py --dir1 ./run1 --dir2 ./run2 [--atol 1e-5] [--rtol 1e-3] + +Compares .npy files between two directories and reports differences. 
+""" + +import argparse +import os +import sys +from pathlib import Path + +import numpy as np + + +def find_npy_files(directory: str) -> dict[str, Path]: + """Find all .npy files in directory (recursively).""" + files = {} + for path in Path(directory).rglob("*.npy"): + rel_path = path.relative_to(directory) + files[str(rel_path)] = path + return files + + +def compare_tensors(file1: Path, file2: Path, atol: float, rtol: float) -> dict: + """Compare two tensor files and return comparison results.""" + arr1 = np.load(file1) + arr2 = np.load(file2) + + result = { + "file": str(file1.name), + "shape1": arr1.shape, + "shape2": arr2.shape, + "dtype1": str(arr1.dtype), + "dtype2": str(arr2.dtype), + "match": False, + "error": None, + } + + if arr1.shape != arr2.shape: + result["error"] = f"Shape mismatch: {arr1.shape} vs {arr2.shape}" + return result + + if arr1.dtype != arr2.dtype: + result["error"] = f"Dtype mismatch: {arr1.dtype} vs {arr2.dtype}" + return result + + arr1_flat = arr1.astype(np.float64).flatten() + arr2_flat = arr2.astype(np.float64).flatten() + + abs_diff = np.abs(arr1_flat - arr2_flat) + max_abs_diff = np.max(abs_diff) + mean_abs_diff = np.mean(abs_diff) + + with np.errstate(divide="ignore", invalid="ignore"): + rel_diff = abs_diff / (np.abs(arr2_flat) + 1e-12) + rel_diff = np.where(np.isfinite(rel_diff), rel_diff, 0) + max_rel_diff = np.max(rel_diff) + mean_rel_diff = np.mean(rel_diff) + + result["max_abs_diff"] = float(max_abs_diff) + result["mean_abs_diff"] = float(mean_abs_diff) + result["max_rel_diff"] = float(max_rel_diff) + result["mean_rel_diff"] = float(mean_rel_diff) + result["match"] = np.allclose(arr1, arr2, atol=atol, rtol=rtol) + + return result + + +def main(): + parser = argparse.ArgumentParser(description="Compare precision check outputs") + parser.add_argument("--dir1", required=True, help="First directory") + parser.add_argument("--dir2", required=True, help="Second directory") + parser.add_argument("--atol", type=float, default=1e-5, help="Absolute tolerance") + parser.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + args = parser.parse_args() + + if not os.path.isdir(args.dir1): + print(f"Error: {args.dir1} is not a directory") + sys.exit(1) + if not os.path.isdir(args.dir2): + print(f"Error: {args.dir2} is not a directory") + sys.exit(1) + + files1 = find_npy_files(args.dir1) + files2 = find_npy_files(args.dir2) + + print(f"Directory 1: {args.dir1} ({len(files1)} files)") + print(f"Directory 2: {args.dir2} ({len(files2)} files)") + print(f"Tolerance: atol={args.atol}, rtol={args.rtol}") + print() + + only_in_1 = set(files1.keys()) - set(files2.keys()) + only_in_2 = set(files2.keys()) - set(files1.keys()) + common = set(files1.keys()) & set(files2.keys()) + + if only_in_1: + print(f"Files only in dir1 ({len(only_in_1)}):") + for f in sorted(only_in_1): + print(f" {f}") + print() + + if only_in_2: + print(f"Files only in dir2 ({len(only_in_2)}):") + for f in sorted(only_in_2): + print(f" {f}") + print() + + if not common: + print("No common files to compare") + sys.exit(1) + + print(f"Comparing {len(common)} common files...") + print() + + passed = 0 + failed = 0 + errors = 0 + + for rel_path in sorted(common): + result = compare_tensors(files1[rel_path], files2[rel_path], args.atol, args.rtol) + + if result["error"]: + errors += 1 + print(f"ERROR: {rel_path}") + print(f" {result['error']}") + elif result["match"]: + passed += 1 + if 
args.verbose: + print(f"PASS: {rel_path}") + print(f" max_abs={result['max_abs_diff']:.2e} max_rel={result['max_rel_diff']:.2e}") + else: + failed += 1 + print(f"FAIL: {rel_path}") + print(f" shape={result['shape1']} dtype={result['dtype1']}") + print(f" max_abs={result['max_abs_diff']:.2e} mean_abs={result['mean_abs_diff']:.2e}") + print(f" max_rel={result['max_rel_diff']:.2e} mean_rel={result['mean_rel_diff']:.2e}") + + print() + print("=" * 50) + print(f"Summary: {passed} passed, {failed} failed, {errors} errors") + print(f"Missing: {len(only_in_1)} in dir1 only, {len(only_in_2)} in dir2 only") + + if failed > 0 or errors > 0: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 785c4a0e..7682a340 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -20,6 +20,7 @@ read_var() { BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" +COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" @@ -83,7 +84,16 @@ run_and_log() { echo "[COMMAND] $cmd" >> "$log_path" # Run the command and append both stdout and stderr to the log file - eval "$cmd" >> "$log_path" 2>&1 + if ! eval "$cmd" >> "$log_path" 2>&1; then + echo -e "\033[1;31m============================================================\033[0m" + echo -e "\033[1;31m[ERROR] Command failed: ${cmd}\033[0m" + echo -e "\033[1;31m[ERROR] See log file for details: ${log_path}\033[0m" + echo -e "\033[1;31m============================================================\033[0m" + echo "" + echo "[ERROR] Last 20 lines of log:" + tail -20 "$log_path" + exit 1 + fi popd > /dev/null @@ -174,3 +184,28 @@ for ((id=0; id #include #include @@ -6,7 +7,9 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h" +#include "infini_train/include/utils/global_module_hook_registry.h" #include "infini_train/include/utils/precision_check_config.h" +#include "infini_train/include/utils/precision_checker.h" using namespace infini_train; @@ -22,6 +25,30 @@ class MyModel : public nn::Module { } }; +// Simple model for multi-iteration test +class SimpleModel : public nn::Module { +public: + SimpleModel() : Module("SimpleModel") {} + + std::vector> Forward(const std::vector> &inputs) override { + auto x = inputs[0]; + x->RequiresGrad(); + auto y = x->Mul(x)->Mul(x); // x^3 + return {y}; + } +}; + +void RunModelForwardBackward(const std::shared_ptr &model) { + auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); + x->Fill(2.0f); + x->RequiresGrad(); + + std::vector> inputs = {x}; + auto outputs = (*model)(inputs); + auto loss = outputs[0]->Sum(0, false)->Sum(0, false); + loss->Backward(); +} + void TestFunctionLevel(const std::string &config_str) { std::cout << "\n========================================" << std::endl; std::cout << " Function-Level Test: " << config_str << std::endl; @@ -42,40 +69,164 @@ void TestFunctionLevel(const std::string &config_str) { std::cout << "Test completed." 
<< std::endl; } -void TestModuleLevel() { +void TestModuleLevel(const std::string &config_str) { std::cout << "\n========================================" << std::endl; - std::cout << " Module-Level Test" << std::endl; + std::cout << " Module-Level Test: " << config_str << std::endl; std::cout << "========================================" << std::endl; auto model = std::make_shared(); + RunModelForwardBackward(model); + + std::cout << "Test completed." << std::endl; +} + +// Test: Simple format output (level=2, format=simple) +void TestSimpleFormat() { + std::cout << "\n========================================" << std::endl; + std::cout << " Test: Simple Format (level=2, format=simple)" << std::endl; + std::cout << "========================================" << std::endl; + auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); x->Fill(2.0f); x->RequiresGrad(); - std::vector> inputs = {x}; - auto outputs = (*model)(inputs); - auto loss = outputs[0]->Sum(0, false)->Sum(0, false); + auto y = x->Mul(x); + auto loss = y->Sum(0, false)->Sum(0, false); // Two Sum ops to produce scalar loss->Backward(); - std::cout << "Test completed." << std::endl; + std::cout << "Simple format test completed - check output for min/max/mean values." << std::endl; +} + +// Test: MD5 format output (level=2, format=md5) +void TestMd5Format() { + std::cout << "\n========================================" << std::endl; + std::cout << " Test: MD5 Format (level=2, format=md5)" << std::endl; + std::cout << "========================================" << std::endl; + + auto x = std::make_shared(std::vector{2, 3}, DataType::kFLOAT32); + x->Fill(2.0f); + x->RequiresGrad(); + + auto y = x->Mul(x); + auto loss = y->Sum(0, false)->Sum(0, false); // Two Sum ops to produce scalar + loss->Backward(); + + std::cout << "MD5 format test completed - check output for md5 hashes." << std::endl; +} + +// Test: Save tensors to NPY files (level=1, save_tensors=true) +void TestSaveTensors() { + std::cout << "\n========================================" << std::endl; + std::cout << " Test: Save Tensors (level=1, save_tensors=true)" << std::endl; + std::cout << "========================================" << std::endl; + + std::string output_path = "/tmp/precision_check_npy"; + + auto model = std::make_shared(); + RunModelForwardBackward(model); + + // Verify NPY files were created + namespace fs = std::filesystem; + bool found_npy = false; + if (fs::exists(output_path)) { + for (const auto &entry : fs::recursive_directory_iterator(output_path)) { + if (entry.path().extension() == ".npy") { + found_npy = true; + std::cout << "Found NPY file: " << entry.path() << std::endl; + } + } + } + + if (found_npy) { + std::cout << "Save tensors test PASSED - NPY files created successfully." << std::endl; + } else { + std::cout << "Save tensors test completed - check output directory for NPY files." 
<< std::endl; + } +} + +// Test: Multi-iteration file overwrite (level=1, save_tensors=true, iter=3) +void TestMultiIterOverwrite() { + std::cout << "\n========================================" << std::endl; + std::cout << " Test: Multi-Iteration File Overwrite" << std::endl; + std::cout << "========================================" << std::endl; + + std::string output_path = "/tmp/precision_check_overwrite"; + + auto model = std::make_shared(); + int num_iters = 3; + + // Run multiple iterations - files should be overwritten + for (int i = 0; i < num_iters; ++i) { + std::cout << "Iteration " << (i + 1) << "/" << num_iters << std::endl; + utils::PrecisionCheckEnv::ResetCounters(); // Reset counters each iteration + RunModelForwardBackward(model); + } + + namespace fs = std::filesystem; + int npy_count = 0; + if (fs::exists(output_path)) { + for (const auto &entry : fs::recursive_directory_iterator(output_path)) { + if (entry.path().extension() == ".npy") { + ++npy_count; + } + } + } + + std::cout << "Multi-iteration test completed - found " << npy_count << " NPY files after " << num_iters + << " iterations." << std::endl; + std::cout << "(Files should be overwritten each iteration, count should be consistent with 1 iter)" << std::endl; } int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); - std::string config_str = argc > 1 ? argv[1] : "level=2"; + std::string config_str = argc > 1 ? argv[1] : ""; std::cout << "========================================" << std::endl; std::cout << " Precision Check Test Suite" << std::endl; std::cout << "========================================" << std::endl; - std::cout << "Config: " << config_str << std::endl; - auto config = utils::PrecisionCheckConfig::Parse(config_str); nn::parallel::global::InitAllEnv(1, 1, false, 1, 1); + + // If no config argument, run all format tests + if (config_str.empty()) { + auto config = utils::PrecisionCheckConfig::Parse("level=2,format=simple"); + utils::PrecisionCheckEnv::Instance().Init(config); + + std::cout << "\nRunning all precision check format tests..." 
<< std::endl; + + // Test 1: Simple format + TestSimpleFormat(); + + // Test 2: MD5 format + auto md5_config = utils::PrecisionCheckConfig::Parse("level=2,format=md5"); + utils::PrecisionCheckEnv::Instance().Init(md5_config); + TestMd5Format(); + + // Test 3: Save tensors + auto npy_config = utils::PrecisionCheckConfig::Parse("level=1,save_tensors=true"); + utils::PrecisionCheckEnv::Instance().Init(npy_config); + TestSaveTensors(); + + // Test 4: Multi-iteration overwrite + auto iter_config = utils::PrecisionCheckConfig::Parse("level=1,save_tensors=true"); + utils::PrecisionCheckEnv::Instance().Init(iter_config); + TestMultiIterOverwrite(); + + std::cout << "\n========================================" << std::endl; + std::cout << " All Tests Completed Successfully" << std::endl; + std::cout << "========================================" << std::endl; + return 0; + } + + // If config provided, run single test (original behavior) + auto config = utils::PrecisionCheckConfig::Parse(config_str); utils::PrecisionCheckEnv::Instance().Init(config); + std::cout << "Config: " << config_str << std::endl; + if (config.level == utils::PrecisionCheckLevel::MODULE) { - TestModuleLevel(); + TestModuleLevel(config_str); } else if (config.level == utils::PrecisionCheckLevel::FUNCTION) { TestFunctionLevel(config_str); } else { @@ -83,7 +234,7 @@ int main(int argc, char *argv[]) { } std::cout << "\n========================================" << std::endl; - std::cout << " All Tests Completed Successfully" << std::endl; + std::cout << " Test Completed" << std::endl; std::cout << "========================================" << std::endl; return 0;