diff --git a/docs/CN/source/cookbook/glm4_deployment.rst b/docs/CN/source/cookbook/glm4_deployment.rst
new file mode 100644
index 000000000..ec63490d4
--- /dev/null
+++ b/docs/CN/source/cookbook/glm4_deployment.rst
@@ -0,0 +1,213 @@
+.. _glm4_deployment:
+
+GLM-4.7-Flash 模型部署指南
+===========================
+
+LightLLM 支持 GLM-4.7-Flash (glm4_moe_lite) 模型系列的部署,该模型采用 MoE 架构。本文档提供详细的部署配置、函数调用和 MTP(多令牌预测)支持信息。
+
+模型概述
+--------
+
+**主要特性:**
+
+- 分组 MoE,支持 top-k 专家选择
+- 支持 ``vanilla_with_att`` 和 ``eagle_with_att`` MTP 模式
+- 兼容 FlashAttention3 后端
+- 支持 XML 风格参数格式的函数调用
+
+模型参考:https://huggingface.co/zai-org/GLM-4.7-Flash
+
+推荐启动脚本 (H200)
+-------------------
+
+**基础启动命令:**
+
+.. code-block:: bash
+
+    LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+    python -m lightllm.server.api_server \
+    --model_dir /path/to/GLM-4.7-Flash/ \
+    --tp 1 \
+    --max_req_total_len 202752 \
+    --chunked_prefill_size 8192 \
+    --llm_prefill_att_backend fa3 \
+    --llm_decode_att_backend flashinfer \
+    --graph_max_batch_size 512 \
+    --tool_call_parser glm47 \
+    --reasoning_parser glm45 \
+    --host 0.0.0.0 \
+    --port 8000
+
+**参数说明:**
+
+- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: 启用 Triton 自动调优以获得最佳内核性能
+- ``LOADWORKER=18``: 模型加载线程数,加快权重加载速度
+- ``--tp 1``: 张量并行度(单 GPU)
+- ``--max_req_total_len 202752``: 最大请求总长度
+- ``--chunked_prefill_size 8192``: 预填充处理的分块大小
+- ``--llm_prefill_att_backend fa3``: 预填充阶段使用 FlashAttention3
+- ``--llm_decode_att_backend flashinfer``: 解码阶段使用 FlashInfer
+- ``--graph_max_batch_size 512``: CUDA graph 最大批处理大小
+- ``--tool_call_parser glm47``: 使用 GLM-4.7 函数调用解析器
+- ``--reasoning_parser glm45``: 使用 GLM-4.5 推理解析器
+
+MTP(多令牌预测)模式
+---------------------
+
+要启用 MTP 进行推测解码,请添加以下参数:
+
+.. code-block:: bash
+
+    LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+    python -m lightllm.server.api_server \
+    --model_dir /path/to/GLM-4.7-Flash/ \
+    --tp 1 \
+    --max_req_total_len 202752 \
+    --chunked_prefill_size 8192 \
+    --llm_prefill_att_backend fa3 \
+    --llm_decode_att_backend flashinfer \
+    --graph_max_batch_size 512 \
+    --tool_call_parser glm47 \
+    --reasoning_parser glm45 \
+    --mtp_step 4 \
+    --mtp_mode eagle_with_att \
+    --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \
+    --host 0.0.0.0 \
+    --port 8000
+
+**MTP 参数说明:**
+
+- ``--mtp_step 4``: 每个 MTP 步骤预测的令牌数
+- ``--mtp_mode eagle_with_att``: MTP 模式(支持 ``vanilla_with_att`` 和 ``eagle_with_att``)
+- ``--mtp_draft_model_dir``: MTP 草稿模型路径
+
+函数调用支持
+------------
+
+GLM-4.7-Flash 使用新的 ``Glm47Detector`` 类来解析 XML 风格的工具调用。
+
+**函数调用格式:**
+
+.. code-block:: xml
+
+    <tool_call>func_name
+    <arg_key>key</arg_key><arg_value>value</arg_value>
+    </tool_call>
+
+**特性:**
+
+- 完整的流式支持,支持增量解析
+- 兼容 OpenAI 风格的函数调用 API
+
+测试与验证
+----------
+
+基础功能测试
+~~~~~~~~~~~~
+
+.. code-block:: bash
+
+    curl http://localhost:8000/generate \
+    -H "Content-Type: application/json" \
+    -d '{
+      "inputs": "什么是人工智能?",
+      "parameters":{
+        "max_new_tokens": 100,
+        "frequency_penalty": 1
+      }
+    }'
+
+OpenAI 兼容聊天接口
+~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+      "model": "GLM-4.7-Flash",
+      "messages": [{"role": "user", "content": "你好"}],
+      "max_tokens": 100
+    }'
+
+性能基准测试
+------------
+
+函数调用测试结果 (BFCL v3)
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
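+在查看下表结果之前,可以先发送一个带 ``tools`` 字段的请求,快速自检 ``--tool_call_parser glm47`` 是否正常工作。以下示例仅作演示,其中的 ``get_weather`` 工具及其参数均为假设:
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+      "model": "GLM-4.7-Flash",
+      "messages": [{"role": "user", "content": "北京今天天气怎么样?"}],
+      "tools": [{
+        "type": "function",
+        "function": {
+          "name": "get_weather",
+          "description": "查询指定城市的当前天气",
+          "parameters": {
+            "type": "object",
+            "properties": {"city": {"type": "string", "description": "城市名称"}},
+            "required": ["city"]
+          }
+        }
+      }]
+    }'
+
+解析成功时,响应消息中的 ``tool_calls`` 字段会给出函数名和 JSON 格式的参数。下表为 BFCL v3 各类别的测试结果:
+
+.. 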
list-table:: + :header-rows: 1 + :widths: 40 20 + + * - 类别 + - LightLLM + * - simple + - 62.50% + * - multiple + - 54.50% + * - parallel + - 69.50% + * - parallel_multiple + - 61.50% + * - java + - 66.00% + * - javascript + - 48.00% + * - irrelevance + - 83.33% + * - live_simple + - 45.74% + * - live_multiple + - 34.00% + * - live_parallel + - 25.00% + * - live_parallel_multiple + - 37.50% + * - rest + - 2.86% + * - sql + - 28.00% + * - **总体** + - **49.12%** + +速度测试结果 (ShareGPT 2000 条提示,4×H200) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 20 20 30 + + * - 工作负载 + - 输出 (tok/s) + - TTFT (ms) + - 端到端延迟 (ms) + * - burst + - 6442 + - 11476 + - 27719 + * - high-conc (512) + - **6728** + - 1099 + - 11240 + * - moderate (10 req/s) + - 1798 + - 196 + - 5746 + * - steady (5 req/s) + - 917 + - 154 + - 2797 + +硬件要求 +-------- + +**测试配置:** + +- 4× NVIDIA H200 (每卡 80GB HBM3) +- NVLink 4.0 互联 + +**最低要求:** + +- 基础部署需要单张 NVIDIA H100/H200 GPU(80GB 显存) +- 生产环境建议使用多 GPU 配置 diff --git a/docs/CN/source/index.rst b/docs/CN/source/index.rst index b97b2c759..58b747359 100755 --- a/docs/CN/source/index.rst +++ b/docs/CN/source/index.rst @@ -57,7 +57,13 @@ Lightllm 整合了众多的开源方案的优点,包括但不限于 FasterTran 思考解析(Reasoning Parser) APIServer 参数详解 lightllm api介绍 - + +.. toctree:: + :maxdepth: 1 + :caption: Cookbook + + GLM-4.7-Flash 部署 + .. toctree:: :maxdepth: 1 :caption: 模型支持 diff --git a/docs/EN/source/cookbook/glm4_deployment.rst b/docs/EN/source/cookbook/glm4_deployment.rst new file mode 100644 index 000000000..c3807f812 --- /dev/null +++ b/docs/EN/source/cookbook/glm4_deployment.rst @@ -0,0 +1,213 @@ +.. _glm4_deployment: + +GLM-4.7-Flash Model Deployment Guide +===================================== + +LightLLM supports deployment of GLM-4.7-Flash (glm4_moe_lite) model family with MoE architecture. This document provides detailed information on deployment configuration, function calling, and MTP (Multi-Token Prediction) support. + +Model Overview +-------------- + +**Key Features:** + +- Grouped MoE with top-k expert selection +- Support for ``vanilla_with_att`` and ``eagle_with_att`` MTP modes +- Compatible with FlashAttention3 backend +- Function calling support with XML-style argument format + +Model Reference: https://huggingface.co/zai-org/GLM-4.7-Flash + +Recommended Launch Script (H200) +-------------------------------- + +**Basic Launch Command:** + +.. 
code-block:: bash
+
+    LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+    python -m lightllm.server.api_server \
+    --model_dir /path/to/GLM-4.7-Flash/ \
+    --tp 1 \
+    --max_req_total_len 202752 \
+    --chunked_prefill_size 8192 \
+    --llm_prefill_att_backend fa3 \
+    --llm_decode_att_backend flashinfer \
+    --graph_max_batch_size 512 \
+    --tool_call_parser glm47 \
+    --reasoning_parser glm45 \
+    --host 0.0.0.0 \
+    --port 8000
+
+**Parameter Description:**
+
+- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: Enable Triton autotuning for optimal kernel performance
+- ``LOADWORKER=18``: Number of model loading threads for faster weight loading
+- ``--tp 1``: Tensor parallelism degree (single GPU)
+- ``--max_req_total_len 202752``: Maximum total request length
+- ``--chunked_prefill_size 8192``: Chunk size for prefill processing
+- ``--llm_prefill_att_backend fa3``: Use FlashAttention3 for prefill
+- ``--llm_decode_att_backend flashinfer``: Use FlashInfer for decode
+- ``--graph_max_batch_size 512``: Maximum batch size for CUDA graph
+- ``--tool_call_parser glm47``: Use GLM-4.7 function calling parser
+- ``--reasoning_parser glm45``: Use GLM-4.5 reasoning parser
+
+MTP (Multi-Token Prediction) Mode
+---------------------------------
+
+To enable MTP for speculative decoding, add the following parameters:
+
+.. code-block:: bash
+
+    LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+    python -m lightllm.server.api_server \
+    --model_dir /path/to/GLM-4.7-Flash/ \
+    --tp 1 \
+    --max_req_total_len 202752 \
+    --chunked_prefill_size 8192 \
+    --llm_prefill_att_backend fa3 \
+    --llm_decode_att_backend flashinfer \
+    --graph_max_batch_size 512 \
+    --tool_call_parser glm47 \
+    --reasoning_parser glm45 \
+    --mtp_step 4 \
+    --mtp_mode eagle_with_att \
+    --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \
+    --host 0.0.0.0 \
+    --port 8000
+
+**MTP Parameters:**
+
+- ``--mtp_step 4``: Number of tokens to predict in each MTP step
+- ``--mtp_mode eagle_with_att``: MTP mode (supports ``vanilla_with_att`` and ``eagle_with_att``)
+- ``--mtp_draft_model_dir``: Path to the draft model for MTP
+
+Function Calling Support
+------------------------
+
+GLM-4.7-Flash uses a new ``Glm47Detector`` class for parsing XML-style tool calls.
+
+**Function Call Format:**
+
+.. code-block:: xml
+
+    <tool_call>func_name
+    <arg_key>key</arg_key><arg_value>value</arg_value>
+    </tool_call>
+
+**Features:**
+
+- Full streaming support for incremental parsing
+- Compatible with OpenAI-style function calling API
+
+Testing and Validation
+----------------------
+
+Basic Functionality Testing
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+    curl http://localhost:8000/generate \
+    -H "Content-Type: application/json" \
+    -d '{
+      "inputs": "What is AI?",
+      "parameters":{
+        "max_new_tokens": 100,
+        "frequency_penalty": 1
+      }
+    }'
+
+OpenAI-Compatible Chat Completions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+      "model": "GLM-4.7-Flash",
+      "messages": [{"role": "user", "content": "Hello"}],
+      "max_tokens": 100
+    }'
+
+Performance Benchmarks
+----------------------
+
+Function Calling Test Results (BFCL v3)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
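+Before reading the table below, you can quickly check that ``--tool_call_parser glm47`` is working by sending a request with a ``tools`` field. The example is illustrative only; the ``get_weather`` tool and its schema are assumptions made for demonstration:
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+      "model": "GLM-4.7-Flash",
+      "messages": [{"role": "user", "content": "What is the weather in Beijing today?"}],
+      "tools": [{
+        "type": "function",
+        "function": {
+          "name": "get_weather",
+          "description": "Get the current weather for a city",
+          "parameters": {
+            "type": "object",
+            "properties": {"city": {"type": "string", "description": "City name"}},
+            "required": ["city"]
+          }
+        }
+      }]
+    }'
+
+When parsing succeeds, the response message contains a ``tool_calls`` field with the function name and JSON-formatted arguments. The table below lists the BFCL v3 results per category:
+
+.. 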
list-table:: + :header-rows: 1 + :widths: 40 20 + + * - Category + - LightLLM + * - simple + - 62.50% + * - multiple + - 54.50% + * - parallel + - 69.50% + * - parallel_multiple + - 61.50% + * - java + - 66.00% + * - javascript + - 48.00% + * - irrelevance + - 83.33% + * - live_simple + - 45.74% + * - live_multiple + - 34.00% + * - live_parallel + - 25.00% + * - live_parallel_multiple + - 37.50% + * - rest + - 2.86% + * - sql + - 28.00% + * - **OVERALL** + - **49.12%** + +Speed Test Results (ShareGPT 2000 prompts, 4×H200) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 20 20 30 + + * - Workload + - Output (tok/s) + - TTFT (ms) + - E2E Latency (ms) + * - burst + - 6442 + - 11476 + - 27719 + * - high-conc (512) + - **6728** + - 1099 + - 11240 + * - moderate (10 req/s) + - 1798 + - 196 + - 5746 + * - steady (5 req/s) + - 917 + - 154 + - 2797 + +Hardware Requirements +--------------------- + +**Tested Configuration:** + +- 4× NVIDIA H200 (80GB HBM3 each) +- NVLink 4.0 interconnect + +**Minimum Requirements:** + +- Single NVIDIA H100/H200 GPU with 80GB memory for basic deployment +- Multiple GPUs recommended for production workloads diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst index 07eaaa42e..5ad3c63c1 100755 --- a/docs/EN/source/index.rst +++ b/docs/EN/source/index.rst @@ -56,7 +56,13 @@ Documentation List Reasoning Parser APIServer Parameters Lightllm API Introduction - + +.. toctree:: + :maxdepth: 1 + :caption: Cookbook + + GLM-4.7-Flash Deployment + .. toctree:: :maxdepth: 1 :caption: Model Support diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 6c74b22e1..c9e831829 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -16,6 +16,8 @@ def __init__(self, model): self.qk_nope_head_dim = model.qk_nope_head_dim self.qk_rope_head_dim = model.qk_rope_head_dim self.kv_lora_rank = model.kv_lora_rank + # v_head_dim may differ from qk_nope_head_dim (e.g., GLM-4.7-Flash: v_head_dim=256, qk_nope_head_dim=192) + self.v_head_dim = getattr(model, "v_head_dim", self.qk_nope_head_dim) self.q_data_type = model.data_type self.kv_data_type = model.data_type self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) @@ -69,7 +71,7 @@ def init_state(self): num_qo_heads=self.backend.tp_q_head_num, num_kv_heads=self.backend.tp_q_head_num, head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, - head_dim_vo=self.backend.qk_nope_head_dim, + head_dim_vo=self.backend.v_head_dim, q_data_type=self.backend.q_data_type, causal=True, sm_scale=self.backend.softmax_scale, @@ -101,7 +103,7 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda") + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) self.prefill_wrapper.run(q, k, v, out=o_tensor) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index 8288193ad..c7edecd10 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -44,7 +44,7 @@ def 
_mla_prefill_att( qk_rope_head_dim = 64 q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] - o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device) k_nope, k_rope = k assert att_control.mla_prefill softmax_scale = att_control.mla_prefill_dict["softmax_scale"] diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e6405e4d7..5c1d2b871 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1022,6 +1022,7 @@ def _gen_special_model_input(self, token_num: int): "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) + or "Glm4MoeLiteMTPModel" in str(self.__class__) ) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 6a9bb79c7..141587ff3 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -81,8 +81,11 @@ def _fwd_kernel_calcu_index_and_block_seq( vsm_count, batch_size, BLOCK_N: tl.constexpr, + MAX_BATCH_SIZE: tl.constexpr, ): - b_seq_len = tl.load(b_seq_len + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + b_seq_len = tl.load( + b_seq_len + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0 + ) total_token_num = tl.sum(b_seq_len) block_seq = tl.cdiv(total_token_num, vsm_count * 4) @@ -93,9 +96,9 @@ def _fwd_kernel_calcu_index_and_block_seq( cumsum_seq_len = tl.cumsum(block_seq_len) batch_start_index = cumsum_seq_len - block_seq_len tl.store( - mid_o_batch_start_index + tl.arange(0, 2048), + mid_o_batch_start_index + tl.arange(0, MAX_BATCH_SIZE), batch_start_index, - mask=tl.arange(0, 2048) < batch_size, + mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, ) tl.store(mid_o_decode_att_block_seq, block_seq) @@ -455,7 +458,6 @@ def gqa_token_decode_attention_flash_decoding_vsm( ) if not hasattr(infer_state, "decode_att_block_seq"): - assert batch_size <= 2048 decode_att_block_seq = torch.empty( [ 1, @@ -477,6 +479,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( num_vsm, batch_size, BLOCK_N=run_config["BLOCK_N"], + MAX_BATCH_SIZE=triton.next_power_of_2(batch_size), num_warps=4, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py index fb0323cd4..2687adf14 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py @@ -227,7 +227,7 @@ def triton_grouped_topk( scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py 
b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 72c3a381e..7ac5a03b5 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -196,10 +196,12 @@ def select_experts( scoring_func=scoring_func, ) else: - group_score_topk_num = 1 - # for deepseek v3 - if topk_group == 4 and num_expert_group == 8 and top_k == 8: + if correction_bias is not None: group_score_topk_num = 2 + elif topk_group == 4 and num_expert_group == 8 and top_k == 8: + group_score_topk_num = 2 + else: + group_score_topk_num = 1 topk_weights, topk_ids = triton_grouped_topk( hidden_states=hidden_states, diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 28839b5f5..063181d99 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -67,7 +67,6 @@ def gqa_token_decode_attention_flash_decoding( ) if not hasattr(infer_state, "decode_att_block_seq"): - assert batch_size <= 2048 decode_att_block_seq = torch.empty( [ 1, @@ -89,6 +88,7 @@ def gqa_token_decode_attention_flash_decoding( vsm_count, batch_size, BLOCK_N=BLOCK_N, + MAX_BATCH_SIZE=triton.next_power_of_2(batch_size), num_warps=4, ) @@ -134,8 +134,11 @@ def _fwd_kernel_calcu_index_and_block_seq( num_sm, batch_size, BLOCK_N: tl.constexpr, + MAX_BATCH_SIZE: tl.constexpr, ): - b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + b_seq_len = tl.load( + b_seq_len_ptr + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0 + ) total_token_num = tl.sum(b_seq_len) block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1 @@ -144,6 +147,10 @@ def _fwd_kernel_calcu_index_and_block_seq( block_seq_len = tl.cdiv(b_seq_len, block_seq) cumsum_seq_len = tl.cumsum(block_seq_len) batch_start_index = cumsum_seq_len - block_seq_len - tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size) + tl.store( + mid_o_batch_start_index_ptr + tl.arange(0, MAX_BATCH_SIZE), + batch_start_index, + mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, + ) tl.store(mid_o_decode_att_block_seq_ptr, block_seq) return diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py index be0635182..d79020844 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py @@ -36,6 +36,9 @@ def _fwd_kernel_with_v( BLOCK_DMODEL: tl.constexpr, BLOCK_ROPE_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_V_DMODEL: tl.constexpr, + ACTUAL_DMODEL: tl.constexpr, + ACTUAL_V_DMODEL: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -53,8 +56,13 @@ def _fwd_kernel_with_v( # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_v_d = tl.arange(0, BLOCK_V_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + d_mask = offs_d < ACTUAL_DMODEL + v_d_mask = offs_v_d < ACTUAL_V_DMODEL + off_q = 
(cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] off_q_rope = ( (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs @@ -63,9 +71,10 @@ def _fwd_kernel_with_v( ) off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] - off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_v_d[None, :] - q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q_mask = (offs_m[:, None] < cur_batch_seq_len) & d_mask[None, :] + q = tl.load(Q_nope + off_q, mask=q_mask, other=0.0) q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K_nope + off_k @@ -75,7 +84,7 @@ def _fwd_kernel_with_v( # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_V_DMODEL], dtype=tl.float32) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) @@ -83,14 +92,16 @@ def _fwd_kernel_with_v( for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- + k_seq_mask = (start_n + offs_n[None, :]) < block_end_loc + k_mask = k_seq_mask & d_mask[:, None] k = tl.load( k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, + mask=k_mask, other=0.0, ) k_rope = tl.load( k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, + mask=k_seq_mask, other=0.0, ) @@ -112,9 +123,11 @@ def _fwd_kernel_with_v( # -- update output accumulator -- acc = acc * alpha[:, None] # update acc + v_seq_mask = (start_n + offs_n[:, None]) < block_end_loc + v_mask = v_seq_mask & v_d_mask[None, :] v = tl.load( v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < block_end_loc, + mask=v_mask, other=0.0, ) p = p.to(v.dtype) @@ -124,9 +137,10 @@ def _fwd_kernel_with_v( acc = acc / l_i[:, None] # initialize pointers to output - off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] + off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_v_d[None, :] out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + o_mask = (offs_m[:, None] < cur_batch_seq_len) & v_d_mask[None, :] + tl.store(out_ptrs, acc, mask=o_mask) return @@ -149,13 +163,14 @@ def context_attention_fwd_with_v( BLOCK = 128 if not is_tesla() else 64 q_nope_dim = q_nope.shape[-1] q_rope_dim = q_rope.shape[-1] + v_dim = v.shape[-1] assert q_nope_dim == k_nope.shape[-1] assert q_rope_dim == k_rope.shape[-1] - assert q_nope_dim in {16, 32, 64, 128, 256, 512} - assert q_rope_dim in {16, 32, 64, 128, 256} - assert q_nope_dim == v.shape[-1] - if q_nope_dim >= 512: + q_nope_dim_padded = triton.next_power_of_2(q_nope_dim) + v_dim_padded = triton.next_power_of_2(v_dim) + + if q_nope_dim_padded >= 512 or v_dim_padded >= 512: BLOCK = 64 if not is_tesla() else 32 else: BLOCK = 128 if not is_tesla() else 64 @@ -167,7 +182,7 @@ def 
context_attention_fwd_with_v( batch, head = b_seq_len.shape[0], q_nope.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - num_warps = 4 if q_nope_dim <= 64 else 8 + num_warps = 4 if q_nope_dim_padded <= 64 else 8 _fwd_kernel_with_v[grid]( q_nope, @@ -194,9 +209,12 @@ def context_attention_fwd_with_v( o.stride(1), b_prompt_cache_len=b_prompt_cache_len, BLOCK_M=BLOCK, - BLOCK_DMODEL=q_nope_dim, + BLOCK_DMODEL=q_nope_dim_padded, BLOCK_ROPE_DMODEL=q_rope_dim, BLOCK_N=BLOCK, + BLOCK_V_DMODEL=v_dim_padded, + ACTUAL_DMODEL=q_nope_dim, + ACTUAL_V_DMODEL=v_dim, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index 3bf023f8a..fa926ad6f 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -28,8 +28,8 @@ def apply( device = input_tensor.device if use_custom_tensor_mananger: out = g_cache_manager.alloc_tensor(shape, dtype, device=device) - else: - out = torch.empty(shape, dtype=dtype, device=device) + else: + out = torch.empty(shape, dtype=dtype, device=device) if bias is None: return torch.mm(input_tensor, weight, out=out) return torch.addmm(bias, input_tensor, weight, out=out) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..ea5ca845d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..588bc5a2a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10240": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "5": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "500": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "5120": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "83200": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..f7df66542 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..27b601a56 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..5f05eb2a6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..72dc716ac --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + 
"BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..ceba290e6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 
4 + }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..063afa697 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,119 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json new file mode 100644 index 000000000..1764ad455 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json @@ -0,0 +1,54 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "16384": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "16640": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json new file mode 100644 index 000000000..f1d903fd2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json new file mode 100644 index 000000000..8da1f07f1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "16640": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "256": { 
+ "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + }, + "8": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json new file mode 100644 index 000000000..610acfaa5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 16, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16640": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..1b75882e7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_SEQ": 32, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 8 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 1 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 2 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 2 + }, + "16384": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "16640": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 2, + "num_warps": 1 + }, + "256": { + 
"BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..056b6a747 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_SEQ": 32, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 8 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 1 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "BLOCK_SEQ": 8, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "16640": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..0c480f44e --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 1 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 4 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 4, + "num_warps": 2 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 1 + }, + "16384": { + "BLOCK_SEQ": 4, + 
"HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "16640": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 3, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 2 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 8 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..70f6af6d6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "100": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..71bc9d341 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,188 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "10240": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "1280": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + 
"BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "4": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "40": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "5": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "500": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "5120": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "65536": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66560": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "83200": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..1bb9f7370 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + 
"NUM_STAGES": 1, + "num_warps": 8 + }, + "4096": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..f133810ae --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,116 @@ +{ + "1": { + "BLOCK_M": 128, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "4": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "65536": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66560": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..796d8f01b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,80 @@ +{ + "1": { + "BLOCK_M": 8, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_M": 1, + 
"BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..dfd9dd729 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,116 @@ +{ + "1": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "65536": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66560": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..c62a2572f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. 
""" - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 539b32dec..095f73679 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index e1e435cce..98cc7c229 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -33,7 +33,7 @@ def __init__(self, layer_num, network_config): self.is_moe = ( network_config["n_routed_experts"] is not None and layer_num >= network_config["first_k_dense_replace"] - and layer_num % network_config["moe_layer_freq"] == 0 + and layer_num % network_config.get("moe_layer_freq", 1) == 0 ) self.n_shared_experts = network_config["n_shared_experts"] @@ -65,10 +65,10 @@ def _bind_ffn(self): if self.is_moe: enable_ep_moe = get_env_start_args().enable_ep_moe if enable_ep_moe: - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) + self._ffn = self._moe_ffn_edp self._tpsp_ffn = self._tpsp_ffn_ep else: - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self) + self._ffn = self._moe_ffn self._tpsp_ffn = self._tpsp_ffn_tp else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) @@ -257,7 +257,7 @@ def _get_o( ) -> torch.Tensor: if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)) + o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) return o_tensor def _tpsp_get_o( @@ -269,7 +269,7 @@ def _tpsp_get_o( if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim) + input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) @@ -302,7 +302,8 @@ def _moe_ffn( if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) - router_logits = layer_weight.moe_gate.mm(hidden_states) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -327,7 +328,8 @@ def _moe_ffn_edp( if self.n_shared_experts is not None: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) - router_logits = layer_weight.moe_gate.mm(hidden_states) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + router_logits = 
layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) ep_output = layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -405,7 +407,8 @@ def overlap_tpsp_token_forward( input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - _0_router_logits = layer_weight.moe_gate.mm(_0_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype)) # 1 hook if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() @@ -439,7 +442,8 @@ def overlap_tpsp_token_forward( _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) # to do gate and disptatch - _1_router_logits = layer_weight.moe_gate.mm(_1_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype)) # 0 hook if getattr(infer_state, "hook", None) is not None: infer_state.hook() @@ -529,7 +533,8 @@ def overlap_tpsp_context_forward( input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - _0_router_logits = layer_weight.moe_gate.mm(_0_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype)) # wait last 1 combine if getattr(infer_state1, "hook", None) is not None: @@ -556,7 +561,8 @@ def overlap_tpsp_context_forward( _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) # to do gate and disptatch - _1_router_logits = layer_weight.moe_gate.mm(_1_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype)) # 0 dispatch execute ( diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 783e70e64..3eb09f917 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -25,7 +25,7 @@ def _parse_config(self): self.is_moe = ( self.network_config_["n_routed_experts"] is not None and self.layer_num_ >= self.network_config_["first_k_dense_replace"] - and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0 + and self.layer_num_ % self.network_config_.get("moe_layer_freq", 1) == 0 ) self.tp_q_head_num_ = self.network_config_["num_attention_heads"] self.tp_q_head_num_ = self.tp_q_head_num_ // self.tp_world_size_ @@ -65,7 +65,9 @@ def _init_weight(self): self._init_norm() def _split_kv_b_proj(self, kv_b_proj_): - kv_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank) + kv_b_proj_ = kv_b_proj_.view( + self.num_attention_heads, self.qk_nope_head_dim + self.v_head_dim, self.kv_lora_rank + ) k_b_proj_, v_b_proj_ = torch.split(kv_b_proj_, [self.qk_nope_head_dim, self.v_head_dim], dim=-2) # num_attention_heads x qk_nope_head_dim x kv_lora_rank k_b_proj_ = k_b_proj_.contiguous().to(kv_b_proj_.dtype) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0739a8a8..e596eed97 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -43,6 +43,7 @@ def _init_some_value(self): self.qk_rope_head_dim = self.config["qk_rope_head_dim"] self.q_lora_rank = self.config["q_lora_rank"] self.kv_lora_rank = self.config["kv_lora_rank"] + 
self.v_head_dim = self.config.get("v_head_dim", self.qk_nope_head_dim) self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim def _init_custom(self): diff --git a/lightllm/models/glm4_moe_lite/__init__.py b/lightllm/models/glm4_moe_lite/__init__.py new file mode 100644 index 000000000..b00657090 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/__init__.py @@ -0,0 +1,4 @@ +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo + +__all__ = ["Glm4MoeLiteTpPartModel", "Glm4MoeLiteInferStateInfo"] diff --git a/lightllm/models/glm4_moe_lite/infer_struct.py b/lightllm/models/glm4_moe_lite/infer_struct.py new file mode 100644 index 000000000..38d879d84 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/infer_struct.py @@ -0,0 +1,6 @@ +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo + + +class Glm4MoeLiteInferStateInfo(Deepseek2InferStateInfo): + def __init__(self): + super().__init__() diff --git a/lightllm/models/glm4_moe_lite/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py new file mode 100644 index 000000000..a95580535 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer + +__all__ = ["Glm4MoeLiteTransformerLayerInfer"] diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..6d0041014 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -0,0 +1,17 @@ +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer + + +class Glm4MoeLiteTransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config): + self._glm4_layer_num = layer_num + self._glm4_first_k_dense = network_config.get("first_k_dense_replace", 0) + self._glm4_has_routed_experts = network_config.get("n_routed_experts") is not None + super().__init__(layer_num, network_config) + + @property + def is_moe(self): + return self._glm4_has_routed_experts and self._glm4_layer_num >= self._glm4_first_k_dense + + @is_moe.setter + def is_moe(self, value): + pass diff --git a/lightllm/models/glm4_moe_lite/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py new file mode 100644 index 000000000..1fd5e36f8 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight + +__all__ = ["Glm4MoeLiteTransformerLayerWeight"] diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..92a5eabe0 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -0,0 +1,55 @@ +import torch +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight + + +class Glm4MoeLiteTransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + 
super().__init__(layer_num, data_type, network_config, quant_cfg) + + def _parse_config(self): + super()._parse_config() + + self.is_moe = self.network_config_.get( + "n_routed_experts" + ) is not None and self.layer_num_ >= self.network_config_.get("first_k_dense_replace", 0) + + from lightllm.utils.envs_utils import get_env_start_args + + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + hidden_size = self.network_config_["hidden_size"] + + self.moe_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=torch.float32, # Router gate needs float32 for numerical stability + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + if self.num_fused_shared_experts == 0: + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py new file mode 100644 index 000000000..367928f2e --- /dev/null +++ b/lightllm/models/glm4_moe_lite/model.py @@ -0,0 +1,54 @@ +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer +from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@ModelRegistry("glm4_moe_lite") +class Glm4MoeLiteTpPartModel(Deepseek2TpPartModel): + + transformer_weight_class = Glm4MoeLiteTransformerLayerWeight + transformer_layer_infer_class = Glm4MoeLiteTransformerLayerInfer + infer_state_class = Glm4MoeLiteInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + + def _init_config(self): + super()._init_config() + if "scoring_func" not in self.config: + self.config["scoring_func"] = "sigmoid" + + def _init_custom(self): + self._init_to_get_yarn_rotary() + dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + + def _init_to_get_yarn_rotary(self): + rope_scaling = self.config.get("rope_scaling") + + if rope_scaling is None: + self._init_glm4_standard_rotary() + else: + super()._init_to_get_yarn_rotary() + + def _init_glm4_standard_rotary(self): + rope_theta = self.config.get("rope_theta", 1000000.0) 
+ qk_rope_head_dim = self.config.get("qk_rope_head_dim", 64) + max_position_embeddings = self.config.get("max_position_embeddings", 202752) + + dim = qk_rope_head_dim + + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.float32) / dim)) + + max_seq_len = max(max_position_embeddings, self.max_seq_length) + t = torch.arange(max_seq_len, device="cpu", dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() diff --git a/lightllm/models/glm4_moe_lite_mtp/__init__.py b/lightllm/models/glm4_moe_lite_mtp/__init__.py new file mode 100644 index 000000000..96b6659c8 --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel + +__all__ = ["Glm4MoeLiteMTPModel"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e357bfa19 --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer + +__all__ = ["Glm4MoeLiteMTPPreLayerInfer"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..1f2dc71d5 --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,82 @@ +import torch + +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer + + +class Glm4MoeLiteMTPPreLayerInfer(LlamaPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + + def _mtp_context_forward( + self, + input_embdings, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def _mtp_token_forward( + self, + input_embdings, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + layer_weight.enorm_weight_( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def context_forward( + self, + input_ids, + infer_state: Glm4MoeLiteInferStateInfo, + 
layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, + input_ids, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..57fe578cf --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) + +__all__ = ["Glm4MoeLiteMTPPreAndPostLayerWeight"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..78b4d411c --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,48 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + RMSNormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Glm4MoeLiteMTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + + mtp_layer_idx = network_config["num_hidden_layers"] + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names=f"model.layers.{mtp_layer_idx}.eh_proj.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(mtp_layer_idx, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.enorm.weight", + data_type=self.data_type_, + ) + + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.hnorm.weight", + data_type=self.data_type_, + ) + + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.shared_head.norm.weight", + data_type=self.data_type_, + ) + + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py new file mode 100644 index 000000000..dea415abb --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -0,0 +1,90 @@ +from typing import List +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel +from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.basemodel import load_hf_weights + + +class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): + + pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight + 
pre_layer_infer_class = Glm4MoeLiteMTPPreLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + + mtp_layer_start = self.config["num_hidden_layers"] + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(mtp_layer_start, mtp_layer_start + num_mtp_layers) + ] + + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] + + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + self.layers_infer = [ + self.transformer_layer_infer_class(i, network_config=self.config) + for i in range(total_pre_layers_num, total_pre_layers_num + num_mtp_layers) + ] + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = self.config.get("num_nextn_predict_layers", 1) + + def autotune_layers(self): + return self.config.get("num_nextn_predict_layers", 1) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3cf04a14c..73b9bad4a 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) @@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") - parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size") + parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether 
to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 1620cff13..9214715b1 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -241,7 +241,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami if start_idx >= len(current_text): return StreamingParseResult() - (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags) + obj, end_idx = _partial_json_loads(current_text[start_idx:], flags) is_current_complete = _is_complete_json(current_text[start_idx : start_idx + end_idx]) @@ -1173,6 +1173,276 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text=current_text) +class Glm47Detector(BaseFormatDetector): + """ + Detector for GLM-4.7/GLM-4.7-Flash model function call format. + + The GLM-4.7 format uses an XML-style envelope with arg_key/arg_value pairs + instead of JSON arguments. + + Format Structure: + ``` + function_name + param1 + value1 + param2 + value2 + + ``` + + Example: + ``` + tool_brave_web_search_post + query + test search + count + 5 + + ``` + + Key Components: + - Tool Call Tags: `` and `` wrap each individual call + - Function Name: Appears on the first line after `` + - Arguments: Pairs of `name` and `value` + + Reference: https://github.com/vllm-project/vllm/blob/main/vllm/tool_parsers/glm4_moe_tool_parser.py + """ + + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + self.tool_call_separator = "\n" + + # Regex patterns for parsing GLM-4.7 tool calls + # Match complete tool call blocks + self.func_call_regex = re.compile(r".*?", re.DOTALL) + # Extract function name and arguments from a tool call block + # Function name can be followed by newline OR directly by + # Pattern: function_name(\n|)... + self.func_detail_regex = re.compile( + r"([^<\n]+?)(?:\n|(?=)|(?=))(.*?)", re.DOTALL + ) + # Extract arg_key/arg_value pairs + self.func_arg_regex = re.compile(r"(.*?)\s*(.*?)", re.DOTALL) + + self._last_arguments = "" + self._normal_text_buffer = "" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a GLM-4.7 format tool call.""" + return self.bot_token in text + + def _parse_xml_arguments(self, arg_text: str) -> dict: + """ + Parse XML-style arguments into a dictionary. + + Args: + arg_text: The text containing / pairs + + Returns: + Dictionary of argument name to value + """ + if not arg_text: + return {} + + args = {} + matches = self.func_arg_regex.findall(arg_text) + for key, value in matches: + key = key.strip() + value = value.strip() + # Try to parse value as JSON for complex types (arrays, objects, numbers, booleans) + try: + parsed_value = json.loads(value) + args[key] = parsed_value + except (json.JSONDecodeError, ValueError): + # Keep as string if not valid JSON + args[key] = value + return args + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: StreamingParseResult with normal_text and parsed calls. 
+ """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + + tool_indices = self._get_tool_indices(tools) + calls = [] + + # Find all ... blocks + match_result_list = self.func_call_regex.findall(text) + + for match_result in match_result_list: + try: + # Extract function name and arguments + func_detail = self.func_detail_regex.search(match_result) + if not func_detail: + logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}") + continue + + func_name = func_detail.group(1).strip() + arg_text = func_detail.group(2) if func_detail.group(2) else "" + + # Validate function name + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue + + # Parse XML arguments to JSON + func_args = self._parse_xml_arguments(arg_text) + + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(func_args, ensure_ascii=False), + ) + ) + except Exception as e: + logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}, error: {str(e)}") + continue + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """ + Streaming incremental parsing for GLM-4.7 tool calls. + + This handles the streaming case where tool calls arrive incrementally. + """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have a tool call starting + if not self.has_tool_call(current_text): + # Check for partial bot_token at the end + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + # Might be partial bot_token, keep buffering + return StreamingParseResult() + + # No tool call, emit as normal text + self._buffer = "" + # Clean up any stray end tokens + if self.eot_token in new_text: + new_text = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=new_text) + + # Build tool indices if not already built + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: List[ToolCallItem] = [] + + try: + # Check if we have a complete tool call + if self.eot_token in current_text: + # We have at least one complete tool call + # Parse all complete tool calls + result = self.detect_and_parse(current_text, tools) + + # Find the end of the last complete tool call + last_end = current_text.rfind(self.eot_token) + if last_end != -1: + remaining = current_text[last_end + len(self.eot_token) :] + self._buffer = remaining.lstrip() + else: + self._buffer = "" + + # Reset state for next tool call + self.current_tool_id = -1 + self.current_tool_name_sent = False + self._last_arguments = "" + + return result + + # We have a partial tool call - try to stream it + # Extract what we can from the partial tool call + tool_call_start = current_text.find(self.bot_token) + if tool_call_start == -1: + return StreamingParseResult() + + # Get content after + content_after_start = current_text[tool_call_start + len(self.bot_token) :] + + # Try to extract function name (first line after ) + newline_pos = content_after_start.find("\n") + if newline_pos == -1: + # Still waiting for function name to complete + return StreamingParseResult() + + func_name = content_after_start[:newline_pos].strip() + + # Initialize state if this is the first tool call + if 
self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + # Ensure we have enough entries + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Check if function name is valid + if func_name and func_name in self._tool_indices: + if not self.current_tool_name_sent: + # Send function name first + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments incrementally + arg_text = content_after_start[newline_pos + 1 :] + current_args = self._parse_xml_arguments(arg_text) + + if current_args: + current_args_json = json.dumps(current_args, ensure_ascii=False) + prev_args = self.prev_tool_call_arr[self.current_tool_id].get("arguments", {}) + prev_args_json = json.dumps(prev_args, ensure_ascii=False) if prev_args else "" + + if current_args_json != prev_args_json: + # Calculate the diff + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in GLM-4.7 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. 
@@ -1185,6 +1455,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, "mistral": MistralDetector, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 805c9b8e5..64310d6b0 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -37,6 +37,7 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -328,6 +329,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif mtp_model_cfg["model_type"] == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "glm4_moe_lite": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) else: assert False, f"error mtp mode {mtp_model_cfg['model_type']}" diff --git a/test/acc/bfcl/eval_bfcl.py b/test/acc/bfcl/eval_bfcl.py new file mode 100644 index 000000000..59b81c26d --- /dev/null +++ b/test/acc/bfcl/eval_bfcl.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +""" +BFCL (Berkeley Function Calling Leaderboard) Evaluation Script for LightLLM + +This script evaluates function/tool calling capabilities on the BFCL benchmark. 
+ +Usage: + # Start LightLLM server first: + python -m lightllm.server.api_server --model_dir /path/to/GLM-4.7-Flash --tp 1 + + # Run evaluation: + python eval_bfcl.py \ + --model_name GLM-4.7-Flash \ + --base_url http://localhost:8000/v1 \ + --test_category simple + +Test Categories: + - simple: Single function calls (400 examples) + - multiple: Select one function from multiple options (200 examples) + - parallel: Multiple function calls in parallel (200 examples) + - parallel_multiple: Combination of parallel and multiple (200 examples) + - java: Java function calls (100 examples) + - javascript: JavaScript function calls (70 examples) + - irrelevance: Detect when no function should be called + - all: Run all categories + +Requirements: + pip install openai tqdm huggingface_hub +""" + +import argparse +import json +import os +import re +import ast +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +from collections import defaultdict + +from tqdm import tqdm + +try: + from openai import OpenAI +except ImportError: + print("Please install openai: pip install openai") + exit(1) + +try: + from huggingface_hub import hf_hub_download +except ImportError: + print("Please install huggingface_hub: pip install huggingface_hub") + exit(1) + + +# BFCL Dataset on HuggingFace +BFCL_REPO = "gorilla-llm/Berkeley-Function-Calling-Leaderboard" + +# Test category mappings to filenames +TEST_CATEGORIES = { + "simple": "BFCL_v3_simple.json", + "multiple": "BFCL_v3_multiple.json", + "parallel": "BFCL_v3_parallel.json", + "parallel_multiple": "BFCL_v3_parallel_multiple.json", + "java": "BFCL_v3_java.json", + "javascript": "BFCL_v3_javascript.json", + "irrelevance": "BFCL_v3_irrelevance.json", + "live_simple": "BFCL_v3_live_simple.json", + "live_multiple": "BFCL_v3_live_multiple.json", + "live_parallel": "BFCL_v3_live_parallel.json", + "live_parallel_multiple": "BFCL_v3_live_parallel_multiple.json", + "rest": "BFCL_v3_rest.json", + "sql": "BFCL_v3_sql.json", +} + +# Possible answer files for ground truth +ANSWER_FILES = { + "simple": "possible_answer/BFCL_v3_simple.json", + "multiple": "possible_answer/BFCL_v3_multiple.json", + "parallel": "possible_answer/BFCL_v3_parallel.json", + "parallel_multiple": "possible_answer/BFCL_v3_parallel_multiple.json", + "java": "possible_answer/BFCL_v3_java.json", + "javascript": "possible_answer/BFCL_v3_javascript.json", + "live_simple": "possible_answer/BFCL_v3_live_simple.json", + "live_multiple": "possible_answer/BFCL_v3_live_multiple.json", + "live_parallel": "possible_answer/BFCL_v3_live_parallel.json", + "live_parallel_multiple": "possible_answer/BFCL_v3_live_parallel_multiple.json", + "sql": "possible_answer/BFCL_v3_sql.json", +} + + +@dataclass +class EvalResult: + """Result of a single evaluation.""" + + task_id: str + category: str + passed: bool + model_output: str + expected: Any + error: Optional[str] = None + + +def download_bfcl_file(filename: str) -> str: + """Download a BFCL file from HuggingFace Hub.""" + try: + local_path = hf_hub_download( + repo_id=BFCL_REPO, + filename=filename, + repo_type="dataset", + ) + return local_path + except Exception as e: + print(f"Error downloading {filename}: {e}") + return None + + +def load_jsonl_or_json(filepath: str) -> List[Dict[str, Any]]: + """Load data from JSON or JSONL file.""" + data = [] + with open(filepath, "r", encoding="utf-8") as f: + content = f.read().strip() + # Try as JSON array first + 
try: + data = json.loads(content) + if isinstance(data, dict): + data = [data] + except json.JSONDecodeError: + # Try as JSONL + f.seek(0) + for line in f: + line = line.strip() + if line: + try: + data.append(json.loads(line)) + except json.JSONDecodeError: + continue + return data + + +def load_bfcl_data(category: str, limit: Optional[int] = None) -> List[Dict[str, Any]]: + """Load BFCL dataset for a specific category.""" + filename = TEST_CATEGORIES.get(category) + if not filename: + print(f"Unknown category: {category}") + return [] + + print(f"Downloading {filename} from HuggingFace...") + filepath = download_bfcl_file(filename) + if not filepath: + return [] + + print(f"Loading data from {filepath}") + data = load_jsonl_or_json(filepath) + + # Also load ground truth answers if available + answer_file = ANSWER_FILES.get(category) + if answer_file: + print(f"Downloading answer file {answer_file}...") + answer_path = download_bfcl_file(answer_file) + if answer_path: + answers = load_jsonl_or_json(answer_path) + # Create a mapping from id to answer + answer_map = {} + for ans in answers: + ans_id = ans.get("id", "") + answer_map[ans_id] = ans.get("ground_truth", ans.get("result", [])) + + # Merge answers into data + for item in data: + item_id = item.get("id", "") + if item_id in answer_map: + item["ground_truth"] = answer_map[item_id] + + if limit: + data = data[:limit] + + print(f"Loaded {len(data)} examples for category: {category}") + return data + + +def fix_schema_types(schema: Any) -> Any: + """ + Fix Python type names to JSON Schema types. + BFCL uses Python type names like 'dict', 'list' but JSON Schema needs 'object', 'array'. + """ + if isinstance(schema, dict): + result = {} + for key, value in schema.items(): + if key == "type" and isinstance(value, str): + # Map Python types to JSON Schema types + type_mapping = { + "dict": "object", + "list": "array", + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "NoneType": "null", + "tuple": "array", + } + result[key] = type_mapping.get(value, value) + else: + result[key] = fix_schema_types(value) + return result + elif isinstance(schema, list): + return [fix_schema_types(item) for item in schema] + else: + return schema + + +def convert_to_openai_tools(functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert BFCL function format to OpenAI tools format.""" + tools = [] + for func in functions: + if isinstance(func, str): + func = json.loads(func) + + # Fix the parameters schema to use valid JSON Schema types + parameters = fix_schema_types(func.get("parameters", {})) + + tool = { + "type": "function", + "function": { + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": parameters, + }, + } + tools.append(tool) + return tools + + +def parse_function_call(response: str) -> List[Dict[str, Any]]: + """Parse function calls from model response.""" + calls = [] + + # Try to parse as JSON array + try: + parsed = json.loads(response) + if isinstance(parsed, list): + return parsed + elif isinstance(parsed, dict): + return [parsed] + except json.JSONDecodeError: + pass + + # Try to find function call patterns + # Pattern 1: function_name(args) + func_pattern = r"(\w+)\s*\((.*?)\)" + matches = re.findall(func_pattern, response, re.DOTALL) + for name, args_str in matches: + try: + # Try to parse args as Python dict/kwargs + args_str = args_str.strip() + if args_str: + # Convert to dict format + args = eval(f"dict({args_str})") + else: + args = {} + 
calls.append({"name": name, "arguments": args}) + except: + pass + + # Pattern 2: JSON-like tool_calls + tool_call_pattern = r'\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]+\})\}' + matches = re.findall(tool_call_pattern, response) + for name, args_str in matches: + try: + args = json.loads(args_str) + calls.append({"name": name, "arguments": args}) + except: + pass + + return calls + + +def extract_tool_calls_from_response(response) -> List[Dict[str, Any]]: + """Extract tool calls from OpenAI API response.""" + calls = [] + + if hasattr(response, "choices") and response.choices: + choice = response.choices[0] + message = choice.message + + # Check for tool_calls in response + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + func = tool_call.function + try: + args = json.loads(func.arguments) if func.arguments else {} + except json.JSONDecodeError: + args = {} + calls.append({"name": func.name, "arguments": args}) + + # Also check content for function calls (some models output in content) + if hasattr(message, "content") and message.content: + content_calls = parse_function_call(message.content) + if content_calls and not calls: + calls = content_calls + + return calls + + +def normalize_value(value: Any) -> Any: + """Normalize values for comparison.""" + if isinstance(value, str): + # Try to parse as number + try: + return float(value) + except ValueError: + return value.lower().strip() + elif isinstance(value, bool): + return value + elif isinstance(value, (int, float)): + return float(value) + elif isinstance(value, list): + return [normalize_value(v) for v in value] + elif isinstance(value, dict): + return {k: normalize_value(v) for k, v in value.items()} + return value + + +def value_matches_expected(predicted_value: Any, expected_values: Any) -> bool: + """ + Check if predicted value matches expected value(s). + BFCL format: expected values can be a list of acceptable values. 
+ """ + # Normalize predicted value + pred_normalized = normalize_value(predicted_value) + + # If expected is a list, check if predicted matches any item + if isinstance(expected_values, list): + for exp_val in expected_values: + exp_normalized = normalize_value(exp_val) + if pred_normalized == exp_normalized: + return True + # Also try string comparison for edge cases + if str(pred_normalized) == str(exp_normalized): + return True + return False + else: + exp_normalized = normalize_value(expected_values) + return pred_normalized == exp_normalized or str(pred_normalized) == str(exp_normalized) + + +def compare_function_calls( + predicted: List[Dict[str, Any]], expected: List[Dict[str, Any]], strict: bool = False +) -> Tuple[bool, str]: + """Compare predicted function calls with expected ones.""" + if not predicted and not expected: + return True, "" + + if len(predicted) != len(expected): + return False, f"Count mismatch: predicted {len(predicted)}, expected {len(expected)}" + + # Sort by function name for comparison + pred_sorted = sorted(predicted, key=lambda x: x.get("name", "")) + exp_sorted = sorted(expected, key=lambda x: x.get("name", "")) + + for pred, exp in zip(pred_sorted, exp_sorted): + pred_name = pred.get("name", "") + exp_name = exp.get("name", "") + + if pred_name != exp_name: + return False, f"Function name mismatch: {pred_name} vs {exp_name}" + + pred_args = pred.get("arguments", {}) + exp_args = exp.get("arguments", {}) + + # Check required arguments match (BFCL format: values are lists of acceptable values) + for key, expected_values in exp_args.items(): + if key not in pred_args: + return False, f"Missing argument {key} in {pred_name}" + if not value_matches_expected(pred_args[key], expected_values): + return False, f"Argument {key} mismatch in {pred_name}" + + return True, "" + + +def parse_expected_output(ground_truth: Any) -> List[Dict[str, Any]]: + """ + Parse expected output from BFCL ground truth. 
+ + BFCL format: [{"func_name": {"arg1": [val1, val2], "arg2": [val3]}}] + Convert to: [{"name": "func_name", "arguments": {"arg1": [val1, val2], "arg2": [val3]}}] + """ + if isinstance(ground_truth, str): + try: + ground_truth = json.loads(ground_truth) + except json.JSONDecodeError: + # Try parsing as Python literal + try: + ground_truth = ast.literal_eval(ground_truth) + except: + return [] + + if not ground_truth: + return [] + + # Ensure it's a list + if isinstance(ground_truth, dict): + ground_truth = [ground_truth] + + result = [] + for item in ground_truth: + if isinstance(item, dict): + # Check if it's already in standard format {"name": ..., "arguments": ...} + if "name" in item and "arguments" in item: + result.append(item) + else: + # BFCL format: {"func_name": {"arg1": [v1], "arg2": [v2]}} + for func_name, args in item.items(): + if isinstance(args, dict): + result.append({"name": func_name, "arguments": args}) + else: + # Handle edge case where args might not be a dict + result.append({"name": func_name, "arguments": {}}) + + return result + + +class BFCLEvaluator: + """BFCL Benchmark Evaluator using OpenAI-compatible API.""" + + def __init__( + self, + base_url: str, + model_name: str, + api_key: str = "EMPTY", + max_tokens: int = 1024, + temperature: float = 0.0, + ): + self.client = OpenAI(base_url=base_url, api_key=api_key) + self.model_name = model_name + self.max_tokens = max_tokens + self.temperature = temperature + + def generate_response( + self, prompt: str, tools: List[Dict[str, Any]], system_prompt: Optional[str] = None + ) -> Tuple[Any, List[Dict[str, Any]]]: + """Generate response from the model with tool calling.""" + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + messages.append({"role": "user", "content": prompt}) + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + tools=tools if tools else None, + tool_choice="auto" if tools else None, + max_tokens=self.max_tokens, + temperature=self.temperature, + ) + tool_calls = extract_tool_calls_from_response(response) + return response, tool_calls + except Exception as e: + print(f"API Error: {e}") + return None, [] + + def evaluate_single(self, item: Dict[str, Any], category: str) -> EvalResult: + """Evaluate a single BFCL example.""" + task_id = item.get("id", "unknown") + + # Extract question and functions + question = item.get("question", [[{"role": "user", "content": ""}]]) + if isinstance(question, str): + prompt = question + elif isinstance(question, list) and question: + if isinstance(question[0], dict): + prompt = question[0].get("content", "") + elif isinstance(question[0], list) and question[0]: + prompt = question[0][0].get("content", "") + else: + prompt = str(question[0]) + else: + prompt = str(question) + + # Get functions + functions = item.get("function", []) + if isinstance(functions, str): + try: + functions = json.loads(functions) + except: + functions = [] + + if not isinstance(functions, list): + functions = [functions] + + # Convert to OpenAI tools format + tools = convert_to_openai_tools(functions) + + # Get expected output + ground_truth = item.get("ground_truth", item.get("answer", [])) + expected = parse_expected_output(ground_truth) + + # Generate response + system_prompt = ( + "You are a helpful assistant that can use tools/functions to help answer questions. " + "When you need to call a function, use the provided tools." 
+ ) + + response, predicted_calls = self.generate_response(prompt, tools, system_prompt) + + if response is None: + return EvalResult( + task_id=task_id, + category=category, + passed=False, + model_output="", + expected=expected, + error="API call failed", + ) + + # For irrelevance category, model should NOT call any function + if "irrelevance" in category.lower(): + passed = len(predicted_calls) == 0 + error = "Model called function when it shouldn't" if not passed else None + else: + # Compare function calls + passed, error = compare_function_calls(predicted_calls, expected) + + model_output = json.dumps(predicted_calls, indent=2) if predicted_calls else str(response) + + return EvalResult( + task_id=task_id, category=category, passed=passed, model_output=model_output, expected=expected, error=error + ) + + def evaluate_category(self, category: str, limit: Optional[int] = None, num_workers: int = 4) -> Dict[str, Any]: + """Evaluate all examples in a category.""" + print(f"\nLoading BFCL dataset for category: {category}") + data = load_bfcl_data(category, limit) + + if not data: + print(f"No data found for category: {category}") + return {"category": category, "total": 0, "passed": 0, "accuracy": 0.0} + + print(f"Loaded {len(data)} examples") + + results = [] + + # Use ThreadPoolExecutor for concurrent evaluation + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = {executor.submit(self.evaluate_single, item, category): item for item in data} + + for future in tqdm(as_completed(futures), total=len(futures), desc=f"Evaluating {category}"): + try: + result = future.result() + results.append(result) + except Exception as e: + print(f"Error evaluating: {e}") + + # Calculate metrics + total = len(results) + passed = sum(1 for r in results if r.passed) + accuracy = passed / total * 100 if total > 0 else 0.0 + + # Collect errors for analysis + errors = defaultdict(int) + for r in results: + if not r.passed and r.error: + errors[r.error[:50]] += 1 + + return { + "category": category, + "total": total, + "passed": passed, + "accuracy": accuracy, + "results": results, + "error_summary": dict(errors), + } + + +def main(): + parser = argparse.ArgumentParser(description="BFCL Evaluation for LightLLM") + parser.add_argument("--model_name", type=str, required=True, help="Model name") + parser.add_argument( + "--base_url", type=str, default="http://localhost:8000/v1", help="OpenAI-compatible API base URL" + ) + parser.add_argument("--api_key", type=str, default="EMPTY", help="API key (use EMPTY for local)") + parser.add_argument( + "--test_category", + type=str, + default="simple", + choices=list(TEST_CATEGORIES.keys()) + ["all"], + help="Test category to evaluate", + ) + parser.add_argument("--limit", type=int, default=None, help="Limit number of examples (for testing)") + parser.add_argument("--num_workers", type=int, default=4, help="Number of concurrent workers") + parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens to generate") + parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") + parser.add_argument("--output", "-o", type=str, default=None, help="Output file for detailed results") + + args = parser.parse_args() + + print("=" * 60) + print("BFCL (Berkeley Function Calling Leaderboard) Evaluation") + print("=" * 60) + print(f"Model: {args.model_name}") + print(f"API URL: {args.base_url}") + print(f"Test Category: {args.test_category}") + print() + + evaluator = BFCLEvaluator( + base_url=args.base_url, + 
model_name=args.model_name, + api_key=args.api_key, + max_tokens=args.max_tokens, + temperature=args.temperature, + ) + + # Determine categories to evaluate + if args.test_category == "all": + categories = list(TEST_CATEGORIES.keys()) + else: + categories = [args.test_category] + + all_results = {} + + for category in categories: + result = evaluator.evaluate_category(category, limit=args.limit, num_workers=args.num_workers) + all_results[category] = result + + print(f"\n{category.upper()} Results:") + print(f" Total: {result['total']}") + print(f" Passed: {result['passed']}") + print(f" Accuracy: {result['accuracy']:.2f}%") + + if result.get("error_summary"): + print(" Common errors:") + for error, count in sorted(result["error_summary"].items(), key=lambda x: -x[1])[:5]: + print(f" - {error}: {count}") + + # Print summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(f"{'Category':<25} {'Total':>8} {'Passed':>8} {'Accuracy':>10}") + print("-" * 60) + + total_all = 0 + passed_all = 0 + + for category, result in all_results.items(): + print(f"{category:<25} {result['total']:>8} {result['passed']:>8} {result['accuracy']:>9.2f}%") + total_all += result["total"] + passed_all += result["passed"] + + if len(all_results) > 1: + print("-" * 60) + overall_acc = passed_all / total_all * 100 if total_all > 0 else 0 + print(f"{'OVERALL':<25} {total_all:>8} {passed_all:>8} {overall_acc:>9.2f}%") + + print("=" * 60) + + # Save detailed results + if args.output: + output_data = { + "model": args.model_name, + "config": { + "base_url": args.base_url, + "max_tokens": args.max_tokens, + "temperature": args.temperature, + }, + "results": { + cat: { + "total": r["total"], + "passed": r["passed"], + "accuracy": r["accuracy"], + "error_summary": r.get("error_summary", {}), + } + for cat, r in all_results.items() + }, + } + with open(args.output, "w") as f: + json.dump(output_data, f, indent=2) + print(f"\nResults saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/test/acc/bfcl/requirements.txt b/test/acc/bfcl/requirements.txt new file mode 100644 index 000000000..e57d2da41 --- /dev/null +++ b/test/acc/bfcl/requirements.txt @@ -0,0 +1,13 @@ +# Evaluation benchmark dependencies +aiohttp>=3.8.0 +tqdm>=4.64.0 +transformers>=4.30.0 +numpy>=1.21.0 +openai>=1.0.0 +huggingface_hub>=0.20.0 + +# Optional: official human-eval package for dataset loading +# pip install git+https://github.com/openai/human-eval.git + +# Optional: official BFCL evaluation package +# pip install bfcl-eval diff --git a/test/acc/bfcl/run_bfcl.sh b/test/acc/bfcl/run_bfcl.sh new file mode 100644 index 000000000..2e68f8380 --- /dev/null +++ b/test/acc/bfcl/run_bfcl.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# BFCL (Berkeley Function Calling Leaderboard) evaluation script for LightLLM +# +# Prerequisites: +# 1. Start LightLLM server: +# python -m lightllm.server.api_server \ +# --model_dir /path/to/GLM-4.7-Flash \ +# --tp 1 \ +# --port 8000 +# +# 2. Install dependencies: +# pip install openai tqdm datasets + +set -e + +# Configuration +MODEL_NAME="${MODEL_NAME:-GLM-4.7-Flash}" +BASE_URL="${BASE_URL:-http://localhost:8000/v1}" +PORT="${PORT:-8000}" +TEST_CATEGORY="${TEST_CATEGORY:-simple}" +NUM_WORKERS="${NUM_WORKERS:-4}" + +# Check if server is running +if ! 
curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "Error: LightLLM server not running on port ${PORT}" + echo "Start the server first with:" + echo " python -m lightllm.server.api_server --model_dir /path/to/model --tp 1 --port ${PORT}" + exit 1 +fi + +echo "==========================================" +echo "BFCL Function Calling Evaluation" +echo "==========================================" +echo "Model: ${MODEL_NAME}" +echo "Server: ${BASE_URL}" +echo "Test Category: ${TEST_CATEGORY}" +echo "" + +# Run evaluation +python "$(dirname "$0")/eval_bfcl.py" \ + --model_name "${MODEL_NAME}" \ + --base_url "${BASE_URL}" \ + --test_category "${TEST_CATEGORY}" \ + --num_workers "${NUM_WORKERS}" \ + --output "bfcl_results_${TEST_CATEGORY}_$(date +%Y%m%d_%H%M%S).json" + +echo "" +echo "Evaluation complete!"
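+
+# ------------------------------------------------------------------
+# Usage examples (illustrative; adjust model name, port, and paths):
+#
+#   # Evaluate every supported BFCL category and keep the JSON report:
+#   TEST_CATEGORY=all NUM_WORKERS=8 bash run_bfcl.sh
+#
+#   # Quick smoke test: 20 "simple" examples against a local server:
+#   python eval_bfcl.py \
+#       --model_name GLM-4.7-Flash \
+#       --base_url http://localhost:8000/v1 \
+#       --test_category simple \
+#       --limit 20
+# ------------------------------------------------------------------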