diff --git a/docs/CN/source/cookbook/glm4_deployment.rst b/docs/CN/source/cookbook/glm4_deployment.rst
new file mode 100644
index 000000000..ec63490d4
--- /dev/null
+++ b/docs/CN/source/cookbook/glm4_deployment.rst
@@ -0,0 +1,213 @@
+.. _glm4_deployment:
+
+GLM-4.7-Flash 模型部署指南
+===========================
+
+LightLLM 支持 GLM-4.7-Flash (glm4_moe_lite) 模型系列的部署,该模型采用 MoE 架构。本文档提供详细的部署配置、函数调用和 MTP(多令牌预测)支持信息。
+
+模型概述
+--------
+
+**主要特性:**
+
+- 分组 MoE,支持 top-k 专家选择
+- 支持 ``vanilla_with_att`` 和 ``eagle_with_att`` MTP 模式
+- 兼容 FlashAttention3 后端
+- 支持 XML 风格参数格式的函数调用
+
+模型参考:https://huggingface.co/zai-org/GLM-4.7-Flash
+
+推荐启动脚本 (H200)
+-------------------
+
+**基础启动命令:**
+
+.. code-block:: bash
+
+ LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+ python -m lightllm.server.api_server \
+ --model_dir /path/to/GLM-4.7-Flash/ \
+ --tp 1 \
+ --max_req_total_len 202752 \
+ --chunked_prefill_size 8192 \
+ --llm_prefill_att_backend fa3 \
+ --llm_decode_att_backend flashinfer \
+ --graph_max_batch_size 512 \
+ --tool_call_parser glm47 \
+ --reasoning_parser glm45 \
+ --host 0.0.0.0 \
+ --port 8000
+
+**参数说明:**
+
+- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: 启用 Triton 自动调优以获得最佳内核性能
+- ``LOADWORKER=18``: 模型加载线程数,加快权重加载速度
+- ``--tp 1``: 张量并行度(单 GPU)
+- ``--max_req_total_len 202752``: 最大请求总长度
+- ``--chunked_prefill_size 8192``: 预填充处理的分块大小
+- ``--llm_prefill_att_backend fa3``: 预填充阶段使用 FlashAttention3
+- ``--llm_decode_att_backend flashinfer``: 解码阶段使用 FlashInfer
+- ``--graph_max_batch_size 512``: CUDA graph 最大批处理大小
+- ``--tool_call_parser glm47``: 使用 GLM-4.7 函数调用解析器
+- ``--reasoning_parser glm45``: 使用 GLM-4.5 推理解析器
+
+MTP(多令牌预测)模式
+---------------------
+
+要启用 MTP 进行推测解码,请添加以下参数:
+
+.. code-block:: bash
+
+ LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+ python -m lightllm.server.api_server \
+ --model_dir /path/to/GLM-4.7-Flash/ \
+ --tp 1 \
+ --max_req_total_len 202752 \
+ --chunked_prefill_size 8192 \
+ --llm_prefill_att_backend fa3 \
+ --llm_decode_att_backend flashinfer \
+ --graph_max_batch_size 512 \
+ --tool_call_parser glm47 \
+ --reasoning_parser glm45 \
+ --mtp_step 4 \
+ --mtp_mode eagle_with_att \
+ --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \
+ --host 0.0.0.0 \
+ --port 8000
+
+**MTP 参数说明:**
+
+- ``--mtp_step 4``: 每个解码步骤中由 MTP 草稿模型预测的草稿令牌数
+- ``--mtp_mode eagle_with_att``: MTP 模式(支持 ``vanilla_with_att`` 和 ``eagle_with_att``)
+- ``--mtp_draft_model_dir``: MTP 草稿模型路径
+
+函数调用支持
+------------
+
+GLM-4.7-Flash 使用新的 ``Glm47Detector`` 类来解析 XML 风格的工具调用。
+
+**函数调用格式:**
+
+.. code-block:: xml
+
+    <tool_call>func_name
+    <arg_key>key</arg_key><arg_value>value</arg_value>
+    </tool_call>
+
+**特性:**
+
+- 完整的流式支持,支持增量解析
+- 兼容 OpenAI 风格的函数调用 API(请求示例见下文)
+
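+以下是一个通过 OpenAI 兼容接口发送带工具定义请求的最小示例(其中的 ``get_weather`` 工具仅用于演示,并非模型或 LightLLM 内置):
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+      -H "Content-Type: application/json" \
+      -d '{
+        "model": "GLM-4.7-Flash",
+        "messages": [{"role": "user", "content": "北京今天天气怎么样?"}],
+        "tools": [{
+          "type": "function",
+          "function": {
+            "name": "get_weather",
+            "description": "查询指定城市的当前天气",
+            "parameters": {
+              "type": "object",
+              "properties": {"city": {"type": "string"}},
+              "required": ["city"]
+            }
+          }
+        }],
+        "max_tokens": 256
+      }'
+
+当模型决定调用工具时,响应消息的 ``tool_calls`` 字段会给出函数名和 JSON 格式的参数;在请求中加入 ``"stream": true`` 即可按增量方式流式返回工具调用内容。
+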
+测试与验证
+----------
+
+基础功能测试
+~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ curl http://localhost:8000/generate \
+ -H "Content-Type: application/json" \
+ -d '{
+ "inputs": "什么是人工智能?",
+ "parameters":{
+ "max_new_tokens": 100,
+ "frequency_penalty": 1
+ }
+ }'
+
+OpenAI 兼容聊天接口
+~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ curl http://localhost:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "GLM-4.7-Flash",
+ "messages": [{"role": "user", "content": "你好"}],
+ "max_tokens": 100
+ }'
+
+性能基准测试
+------------
+
+函数调用测试结果 (BFCL v3)
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 40 20
+
+ * - 类别
+ - LightLLM
+ * - simple
+ - 62.50%
+ * - multiple
+ - 54.50%
+ * - parallel
+ - 69.50%
+ * - parallel_multiple
+ - 61.50%
+ * - java
+ - 66.00%
+ * - javascript
+ - 48.00%
+ * - irrelevance
+ - 83.33%
+ * - live_simple
+ - 45.74%
+ * - live_multiple
+ - 34.00%
+ * - live_parallel
+ - 25.00%
+ * - live_parallel_multiple
+ - 37.50%
+ * - rest
+ - 2.86%
+ * - sql
+ - 28.00%
+ * - **总体**
+ - **49.12%**
+
+速度测试结果 (ShareGPT 2000 条提示,4×H200)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 30 20 20 30
+
+ * - 工作负载
+ - 输出 (tok/s)
+ - TTFT (ms)
+ - 端到端延迟 (ms)
+ * - burst
+ - 6442
+ - 11476
+ - 27719
+ * - high-conc (512)
+ - **6728**
+ - 1099
+ - 11240
+ * - moderate (10 req/s)
+ - 1798
+ - 196
+ - 5746
+ * - steady (5 req/s)
+ - 917
+ - 154
+ - 2797
+
+硬件要求
+--------
+
+**测试配置:**
+
+- 4× NVIDIA H200 (每卡 141GB HBM3e)
+- NVLink 4.0 互联
+
+**最低要求:**
+
+- 基础部署需要单张 NVIDIA H100/H200 GPU(至少 80GB 显存)
+- 生产环境建议使用多 GPU 配置
diff --git a/docs/CN/source/index.rst b/docs/CN/source/index.rst
index b97b2c759..58b747359 100755
--- a/docs/CN/source/index.rst
+++ b/docs/CN/source/index.rst
@@ -57,7 +57,13 @@ Lightllm 整合了众多的开源方案的优点,包括但不限于 FasterTran
思考解析(Reasoning Parser)
APIServer 参数详解
lightllm api介绍
-
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Cookbook
+
+   GLM-4.7-Flash 部署 <cookbook/glm4_deployment>
+
.. toctree::
:maxdepth: 1
:caption: 模型支持
diff --git a/docs/EN/source/cookbook/glm4_deployment.rst b/docs/EN/source/cookbook/glm4_deployment.rst
new file mode 100644
index 000000000..c3807f812
--- /dev/null
+++ b/docs/EN/source/cookbook/glm4_deployment.rst
@@ -0,0 +1,213 @@
+.. _glm4_deployment:
+
+GLM-4.7-Flash Model Deployment Guide
+=====================================
+
+LightLLM supports deployment of the GLM-4.7-Flash (glm4_moe_lite) model family, which uses a Mixture-of-Experts (MoE) architecture. This document provides detailed information on deployment configuration, function calling, and MTP (Multi-Token Prediction) support.
+
+Model Overview
+--------------
+
+**Key Features:**
+
+- Grouped MoE with top-k expert selection
+- Support for ``vanilla_with_att`` and ``eagle_with_att`` MTP modes
+- Compatible with FlashAttention3 backend
+- Function calling support with XML-style argument format
+
+Model Reference: https://huggingface.co/zai-org/GLM-4.7-Flash
+
+Recommended Launch Script (H200)
+--------------------------------
+
+**Basic Launch Command:**
+
+.. code-block:: bash
+
+ LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+ python -m lightllm.server.api_server \
+ --model_dir /path/to/GLM-4.7-Flash/ \
+ --tp 1 \
+ --max_req_total_len 202752 \
+ --chunked_prefill_size 8192 \
+ --llm_prefill_att_backend fa3 \
+ --llm_decode_att_backend flashinfer \
+ --graph_max_batch_size 512 \
+ --tool_call_parser glm47 \
+ --reasoning_parser glm45 \
+ --host 0.0.0.0 \
+ --port 8000
+
+**Parameter Description:**
+
+- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: Enable Triton autotuning for optimal kernel performance
+- ``LOADWORKER=18``: Number of model loading threads for faster weight loading
+- ``--tp 1``: Tensor parallelism degree (single GPU)
+- ``--max_req_total_len 202752``: Maximum total request length
+- ``--chunked_prefill_size 8192``: Chunk size for prefill processing
+- ``--llm_prefill_att_backend fa3``: Use FlashAttention3 for prefill
+- ``--llm_decode_att_backend flashinfer``: Use FlashInfer for decode
+- ``--graph_max_batch_size 512``: Maximum batch size for CUDA graph
+- ``--tool_call_parser glm47``: Use GLM-4.7 function calling parser
+- ``--reasoning_parser glm45``: Use GLM-4.5 reasoning parser
+
+MTP (Multi-Token Prediction) Mode
+---------------------------------
+
+To enable MTP for speculative decoding, add the following parameters:
+
+.. code-block:: bash
+
+ LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \
+ python -m lightllm.server.api_server \
+ --model_dir /path/to/GLM-4.7-Flash/ \
+ --tp 1 \
+ --max_req_total_len 202752 \
+ --chunked_prefill_size 8192 \
+ --llm_prefill_att_backend fa3 \
+ --llm_decode_att_backend flashinfer \
+ --graph_max_batch_size 512 \
+ --tool_call_parser glm47 \
+ --reasoning_parser glm45 \
+ --mtp_step 4 \
+ --mtp_mode eagle_with_att \
+ --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \
+ --host 0.0.0.0 \
+ --port 8000
+
+**MTP Parameters:**
+
+- ``--mtp_step 4``: Number of draft tokens the MTP draft model predicts per decode step
+- ``--mtp_mode eagle_with_att``: MTP mode (supports ``vanilla_with_att`` and ``eagle_with_att``)
+- ``--mtp_draft_model_dir``: Path to the draft model for MTP
+
+Function Calling Support
+------------------------
+
+GLM-4.7-Flash uses a new ``Glm47Detector`` class for parsing XML-style tool calls.
+
+**Function Call Format:**
+
+.. code-block:: xml
+
+    <tool_call>func_name
+    <arg_key>key</arg_key><arg_value>value</arg_value>
+    </tool_call>
+
+**Features:**
+
+- Full streaming support for incremental parsing
+- Compatible with OpenAI-style function calling API (see the example request below)
+
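+A minimal example of sending a request with a tool definition through the OpenAI-compatible endpoint is shown below (the ``get_weather`` tool is illustrative only and is not shipped with the model or LightLLM):
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+      -H "Content-Type: application/json" \
+      -d '{
+        "model": "GLM-4.7-Flash",
+        "messages": [{"role": "user", "content": "What is the weather in Beijing today?"}],
+        "tools": [{
+          "type": "function",
+          "function": {
+            "name": "get_weather",
+            "description": "Get the current weather for a given city",
+            "parameters": {
+              "type": "object",
+              "properties": {"city": {"type": "string"}},
+              "required": ["city"]
+            }
+          }
+        }],
+        "max_tokens": 256
+      }'
+
+When the model decides to call the tool, the response message carries the function name and JSON arguments in the ``tool_calls`` field; adding ``"stream": true`` to the request returns the tool call incrementally as it is parsed.
+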
+Testing and Validation
+----------------------
+
+Basic Functionality Testing
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ curl http://localhost:8000/generate \
+ -H "Content-Type: application/json" \
+ -d '{
+ "inputs": "What is AI?",
+ "parameters":{
+ "max_new_tokens": 100,
+ "frequency_penalty": 1
+ }
+ }'
+
+OpenAI-Compatible Chat Completions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ curl http://localhost:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "GLM-4.7-Flash",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 100
+ }'
+
+Performance Benchmarks
+----------------------
+
+Function Calling Test Results (BFCL v3)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 40 20
+
+ * - Category
+ - LightLLM
+ * - simple
+ - 62.50%
+ * - multiple
+ - 54.50%
+ * - parallel
+ - 69.50%
+ * - parallel_multiple
+ - 61.50%
+ * - java
+ - 66.00%
+ * - javascript
+ - 48.00%
+ * - irrelevance
+ - 83.33%
+ * - live_simple
+ - 45.74%
+ * - live_multiple
+ - 34.00%
+ * - live_parallel
+ - 25.00%
+ * - live_parallel_multiple
+ - 37.50%
+ * - rest
+ - 2.86%
+ * - sql
+ - 28.00%
+ * - **OVERALL**
+ - **49.12%**
+
+Speed Test Results (ShareGPT 2000 prompts, 4×H200)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. list-table::
+ :header-rows: 1
+ :widths: 30 20 20 30
+
+ * - Workload
+ - Output (tok/s)
+ - TTFT (ms)
+ - E2E Latency (ms)
+ * - burst
+ - 6442
+ - 11476
+ - 27719
+ * - high-conc (512)
+ - **6728**
+ - 1099
+ - 11240
+ * - moderate (10 req/s)
+ - 1798
+ - 196
+ - 5746
+ * - steady (5 req/s)
+ - 917
+ - 154
+ - 2797
+
+Hardware Requirements
+---------------------
+
+**Tested Configuration:**
+
+- 4× NVIDIA H200 (141GB HBM3e each)
+- NVLink 4.0 interconnect
+
+**Minimum Requirements:**
+
+- Single NVIDIA H100/H200 GPU with at least 80GB of memory for basic deployment
+- Multiple GPUs recommended for production workloads
diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst
index 07eaaa42e..5ad3c63c1 100755
--- a/docs/EN/source/index.rst
+++ b/docs/EN/source/index.rst
@@ -56,7 +56,13 @@ Documentation List
Reasoning Parser
APIServer Parameters
Lightllm API Introduction
-
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Cookbook
+
+   GLM-4.7-Flash Deployment <cookbook/glm4_deployment>
+
.. toctree::
:maxdepth: 1
:caption: Model Support
diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py
index 6c74b22e1..c9e831829 100644
--- a/lightllm/common/basemodel/attention/flashinfer/mla.py
+++ b/lightllm/common/basemodel/attention/flashinfer/mla.py
@@ -16,6 +16,8 @@ def __init__(self, model):
self.qk_nope_head_dim = model.qk_nope_head_dim
self.qk_rope_head_dim = model.qk_rope_head_dim
self.kv_lora_rank = model.kv_lora_rank
+ # v_head_dim may differ from qk_nope_head_dim (e.g., GLM-4.7-Flash: v_head_dim=256, qk_nope_head_dim=192)
+ self.v_head_dim = getattr(model, "v_head_dim", self.qk_nope_head_dim)
self.q_data_type = model.data_type
self.kv_data_type = model.data_type
self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id())
@@ -69,7 +71,7 @@ def init_state(self):
num_qo_heads=self.backend.tp_q_head_num,
num_kv_heads=self.backend.tp_q_head_num,
head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim,
- head_dim_vo=self.backend.qk_nope_head_dim,
+ head_dim_vo=self.backend.v_head_dim,
q_data_type=self.backend.q_data_type,
causal=True,
sm_scale=self.backend.softmax_scale,
@@ -101,7 +103,7 @@ def _mla_prefill_att(
) -> torch.Tensor:
self.backend: MlaFlashInferAttBackend = self.backend # for typing
k_nope, k_rope = k
- o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda")
+ o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda")
q_head_num = q.shape[1]
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1)
self.prefill_wrapper.run(q, k, v, out=o_tensor)
diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py
index 8288193ad..c7edecd10 100644
--- a/lightllm/common/basemodel/attention/triton/mla.py
+++ b/lightllm/common/basemodel/attention/triton/mla.py
@@ -44,7 +44,7 @@ def _mla_prefill_att(
qk_rope_head_dim = 64
q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:]
- o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device)
+ o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device)
k_nope, k_rope = k
assert att_control.mla_prefill
softmax_scale = att_control.mla_prefill_dict["softmax_scale"]
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index e6405e4d7..5c1d2b871 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -1022,6 +1022,7 @@ def _gen_special_model_input(self, token_num: int):
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
+ or "Glm4MoeLiteMTPModel" in str(self.__class__)
)
if is_mtp_draft_model:
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py
index 6a9bb79c7..141587ff3 100644
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py
+++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py
@@ -81,8 +81,11 @@ def _fwd_kernel_calcu_index_and_block_seq(
vsm_count,
batch_size,
BLOCK_N: tl.constexpr,
+ MAX_BATCH_SIZE: tl.constexpr,
):
- b_seq_len = tl.load(b_seq_len + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
+ b_seq_len = tl.load(
+ b_seq_len + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0
+ )
total_token_num = tl.sum(b_seq_len)
block_seq = tl.cdiv(total_token_num, vsm_count * 4)
@@ -93,9 +96,9 @@ def _fwd_kernel_calcu_index_and_block_seq(
cumsum_seq_len = tl.cumsum(block_seq_len)
batch_start_index = cumsum_seq_len - block_seq_len
tl.store(
- mid_o_batch_start_index + tl.arange(0, 2048),
+ mid_o_batch_start_index + tl.arange(0, MAX_BATCH_SIZE),
batch_start_index,
- mask=tl.arange(0, 2048) < batch_size,
+ mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size,
)
tl.store(mid_o_decode_att_block_seq, block_seq)
@@ -455,7 +458,6 @@ def gqa_token_decode_attention_flash_decoding_vsm(
)
if not hasattr(infer_state, "decode_att_block_seq"):
- assert batch_size <= 2048
decode_att_block_seq = torch.empty(
[
1,
@@ -477,6 +479,7 @@ def gqa_token_decode_attention_flash_decoding_vsm(
num_vsm,
batch_size,
BLOCK_N=run_config["BLOCK_N"],
+ MAX_BATCH_SIZE=triton.next_power_of_2(batch_size),
num_warps=4,
)
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
index fb0323cd4..2687adf14 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py
@@ -227,7 +227,7 @@ def triton_grouped_topk(
scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
- out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda")
+ out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda")
assert total_expert_num % num_expert_group == 0
diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
index 72c3a381e..7ac5a03b5 100644
--- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
+++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py
@@ -196,10 +196,12 @@ def select_experts(
scoring_func=scoring_func,
)
else:
- group_score_topk_num = 1
- # for deepseek v3
- if topk_group == 4 and num_expert_group == 8 and top_k == 8:
+ if correction_bias is not None:
group_score_topk_num = 2
+ elif topk_group == 4 and num_expert_group == 8 and top_k == 8:
+ group_score_topk_num = 2
+ else:
+ group_score_topk_num = 1
topk_weights, topk_ids = triton_grouped_topk(
hidden_states=hidden_states,
diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py
index 28839b5f5..063181d99 100644
--- a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py
+++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py
@@ -67,7 +67,6 @@ def gqa_token_decode_attention_flash_decoding(
)
if not hasattr(infer_state, "decode_att_block_seq"):
- assert batch_size <= 2048
decode_att_block_seq = torch.empty(
[
1,
@@ -89,6 +88,7 @@ def gqa_token_decode_attention_flash_decoding(
vsm_count,
batch_size,
BLOCK_N=BLOCK_N,
+ MAX_BATCH_SIZE=triton.next_power_of_2(batch_size),
num_warps=4,
)
@@ -134,8 +134,11 @@ def _fwd_kernel_calcu_index_and_block_seq(
num_sm,
batch_size,
BLOCK_N: tl.constexpr,
+ MAX_BATCH_SIZE: tl.constexpr,
):
- b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
+ b_seq_len = tl.load(
+ b_seq_len_ptr + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0
+ )
total_token_num = tl.sum(b_seq_len)
block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1
@@ -144,6 +147,10 @@ def _fwd_kernel_calcu_index_and_block_seq(
block_seq_len = tl.cdiv(b_seq_len, block_seq)
cumsum_seq_len = tl.cumsum(block_seq_len)
batch_start_index = cumsum_seq_len - block_seq_len
- tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size)
+ tl.store(
+ mid_o_batch_start_index_ptr + tl.arange(0, MAX_BATCH_SIZE),
+ batch_start_index,
+ mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size,
+ )
tl.store(mid_o_decode_att_block_seq_ptr, block_seq)
return
diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py
index be0635182..d79020844 100644
--- a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py
+++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py
@@ -36,6 +36,9 @@ def _fwd_kernel_with_v(
BLOCK_DMODEL: tl.constexpr,
BLOCK_ROPE_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
+ BLOCK_V_DMODEL: tl.constexpr,
+ ACTUAL_DMODEL: tl.constexpr,
+ ACTUAL_V_DMODEL: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@@ -53,8 +56,13 @@ def _fwd_kernel_with_v(
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_v_d = tl.arange(0, BLOCK_V_DMODEL)
offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ d_mask = offs_d < ACTUAL_DMODEL
+ v_d_mask = offs_v_d < ACTUAL_V_DMODEL
+
off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
off_q_rope = (
(cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
@@ -63,9 +71,10 @@ def _fwd_kernel_with_v(
)
off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None]
off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None]
- off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :]
+ off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_v_d[None, :]
- q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+ q_mask = (offs_m[:, None] < cur_batch_seq_len) & d_mask[None, :]
+ q = tl.load(Q_nope + off_q, mask=q_mask, other=0.0)
q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K_nope + off_k
@@ -75,7 +84,7 @@ def _fwd_kernel_with_v(
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_V_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)
@@ -83,14 +92,16 @@ def _fwd_kernel_with_v(
for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
+ k_seq_mask = (start_n + offs_n[None, :]) < block_end_loc
+ k_mask = k_seq_mask & d_mask[:, None]
k = tl.load(
k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
- mask=(start_n + offs_n[None, :]) < block_end_loc,
+ mask=k_mask,
other=0.0,
)
k_rope = tl.load(
k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
- mask=(start_n + offs_n[None, :]) < block_end_loc,
+ mask=k_seq_mask,
other=0.0,
)
@@ -112,9 +123,11 @@ def _fwd_kernel_with_v(
# -- update output accumulator --
acc = acc * alpha[:, None]
# update acc
+ v_seq_mask = (start_n + offs_n[:, None]) < block_end_loc
+ v_mask = v_seq_mask & v_d_mask[None, :]
v = tl.load(
v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
- mask=(start_n + offs_n[:, None]) < block_end_loc,
+ mask=v_mask,
other=0.0,
)
p = p.to(v.dtype)
@@ -124,9 +137,10 @@ def _fwd_kernel_with_v(
acc = acc / l_i[:, None]
# initialize pointers to output
- off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+ off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_v_d[None, :]
out_ptrs = Out + off_o
- tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+ o_mask = (offs_m[:, None] < cur_batch_seq_len) & v_d_mask[None, :]
+ tl.store(out_ptrs, acc, mask=o_mask)
return
@@ -149,13 +163,14 @@ def context_attention_fwd_with_v(
BLOCK = 128 if not is_tesla() else 64
q_nope_dim = q_nope.shape[-1]
q_rope_dim = q_rope.shape[-1]
+ v_dim = v.shape[-1]
assert q_nope_dim == k_nope.shape[-1]
assert q_rope_dim == k_rope.shape[-1]
- assert q_nope_dim in {16, 32, 64, 128, 256, 512}
- assert q_rope_dim in {16, 32, 64, 128, 256}
- assert q_nope_dim == v.shape[-1]
- if q_nope_dim >= 512:
+ q_nope_dim_padded = triton.next_power_of_2(q_nope_dim)
+ v_dim_padded = triton.next_power_of_2(v_dim)
+
+ if q_nope_dim_padded >= 512 or v_dim_padded >= 512:
BLOCK = 64 if not is_tesla() else 32
else:
BLOCK = 128 if not is_tesla() else 64
@@ -167,7 +182,7 @@ def context_attention_fwd_with_v(
batch, head = b_seq_len.shape[0], q_nope.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
- num_warps = 4 if q_nope_dim <= 64 else 8
+ num_warps = 4 if q_nope_dim_padded <= 64 else 8
_fwd_kernel_with_v[grid](
q_nope,
@@ -194,9 +209,12 @@ def context_attention_fwd_with_v(
o.stride(1),
b_prompt_cache_len=b_prompt_cache_len,
BLOCK_M=BLOCK,
- BLOCK_DMODEL=q_nope_dim,
+ BLOCK_DMODEL=q_nope_dim_padded,
BLOCK_ROPE_DMODEL=q_rope_dim,
BLOCK_N=BLOCK,
+ BLOCK_V_DMODEL=v_dim_padded,
+ ACTUAL_DMODEL=q_nope_dim,
+ ACTUAL_V_DMODEL=v_dim,
num_warps=num_warps,
num_stages=1,
)
diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py
index 3bf023f8a..fa926ad6f 100644
--- a/lightllm/common/quantization/no_quant.py
+++ b/lightllm/common/quantization/no_quant.py
@@ -28,8 +28,8 @@ def apply(
device = input_tensor.device
if use_custom_tensor_mananger:
out = g_cache_manager.alloc_tensor(shape, dtype, device=device)
- else:
- out = torch.empty(shape, dtype=dtype, device=device)
+ else:
+ out = torch.empty(shape, dtype=dtype, device=device)
if bias is None:
return torch.mm(input_tensor, weight, out=out)
return torch.addmm(bias, input_tensor, weight, out=out)
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..ea5ca845d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "4": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "400": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "65536": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "66560": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..588bc5a2a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "10240": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1280": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "20480": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "320": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "40": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "5": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "500": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "5120": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "640": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "80": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "83200": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..f7df66542
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..27b601a56
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..5f05eb2a6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..72dc716ac
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..ceba290e6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "4": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "400": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "65536": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "66560": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 000000000..063afa697
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,119 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "400": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "65536": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "66560": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json
new file mode 100644
index 000000000..1764ad455
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json
@@ -0,0 +1,54 @@
+{
+ "1": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "16384": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "16640": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json
new file mode 100644
index 000000000..f1d903fd2
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 256,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json
new file mode 100644
index 000000000..8da1f07f1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_DIM": 64,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "16640": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 16
+ },
+ "8": {
+ "BLOCK_DIM": 64,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json
new file mode 100644
index 000000000..610acfaa5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 16,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16640": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..1b75882e7
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_SEQ": 32,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 2,
+ "num_warps": 2
+ },
+ "16384": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 1,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 2,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..056b6a747
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_SEQ": 32,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "16384": {
+ "BLOCK_SEQ": 8,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_SEQ": 4,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 1,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..0c480f44e
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "100": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 2,
+ "num_stages": 4,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 16,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_SEQ": 4,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_SEQ": 4,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_SEQ": 2,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "32": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 1,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 1,
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SEQ": 1,
+ "HEAD_PARALLEL_NUM": 8,
+ "num_stages": 3,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..70f6af6d6
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "100": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..71bc9d341
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,188 @@
+{
+ "1": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "10240": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "1280": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "20480": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "320": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "4": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "40": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "400": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "5": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "500": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "5120": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "640": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "65536": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "66560": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "80": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "83200": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..1bb9f7370
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..f133810ae
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,116 @@
+{
+ "1": {
+ "BLOCK_M": 128,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 8
+ },
+ "4": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "400": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "65536": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "66560": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..796d8f01b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,80 @@
+{
+ "1": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 32,
+ "NUM_STAGES": 2,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 000000000..dfd9dd729
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,116 @@
+{
+ "1": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16640": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "4": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "400": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "65536": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "66560": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py
index a919f7b28..c62a2572f 100644
--- a/lightllm/common/triton_utils/autotuner.py
+++ b/lightllm/common/triton_utils/autotuner.py
@@ -62,7 +62,7 @@ def autotune(
as needed before invocation.
"""
- def decorator(fn):
+ def decorator(fn: Callable) -> Callable:
return Autotuner(
fn=fn,
kernel_name=kernel_name,
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 539b32dec..095f73679 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -18,6 +18,7 @@
from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel
from lightllm.models.phi3.model import Phi3TpPartModel
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
+from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
from lightllm.models.internvl.model import (
InternVLLlamaTpPartModel,
InternVLPhi3TpPartModel,
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index e1e435cce..98cc7c229 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -33,7 +33,7 @@ def __init__(self, layer_num, network_config):
self.is_moe = (
network_config["n_routed_experts"] is not None
and layer_num >= network_config["first_k_dense_replace"]
- and layer_num % network_config["moe_layer_freq"] == 0
+ and layer_num % network_config.get("moe_layer_freq", 1) == 0
)
self.n_shared_experts = network_config["n_shared_experts"]
@@ -65,10 +65,10 @@ def _bind_ffn(self):
if self.is_moe:
enable_ep_moe = get_env_start_args().enable_ep_moe
if enable_ep_moe:
- self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self)
+ self._ffn = self._moe_ffn_edp
self._tpsp_ffn = self._tpsp_ffn_ep
else:
- self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
+ self._ffn = self._moe_ffn
self._tpsp_ffn = self._tpsp_ffn_tp
else:
self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
@@ -257,7 +257,7 @@ def _get_o(
) -> torch.Tensor:
if input.shape[2] == self.kv_lora_rank:
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
- o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
+ o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim))
return o_tensor
def _tpsp_get_o(
@@ -269,7 +269,7 @@ def _tpsp_get_o(
if input.shape[2] == self.kv_lora_rank:
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
- input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)
+ input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)
dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :])
@@ -302,7 +302,8 @@ def _moe_ffn(
if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0:
shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)
- router_logits = layer_weight.moe_gate.mm(hidden_states)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype))
layer_weight.experts.experts(
hidden_states,
router_logits=router_logits,
@@ -327,7 +328,8 @@ def _moe_ffn_edp(
if self.n_shared_experts is not None:
shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight)
- router_logits = layer_weight.moe_gate.mm(hidden_states)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype))
ep_output = layer_weight.experts.experts(
hidden_states,
router_logits=router_logits,
@@ -405,7 +407,8 @@ def overlap_tpsp_token_forward(
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
- _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype))
# 1 hook
if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
@@ -439,7 +442,8 @@ def overlap_tpsp_token_forward(
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch
- _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype))
# 0 hook
if getattr(infer_state, "hook", None) is not None:
infer_state.hook()
@@ -529,7 +533,8 @@ def overlap_tpsp_context_forward(
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
- _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype))
# wait last 1 combine
if getattr(infer_state1, "hook", None) is not None:
@@ -556,7 +561,8 @@ def overlap_tpsp_context_forward(
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch
- _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+ moe_gate_dtype = layer_weight.moe_gate.data_type_
+ _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype))
# 0 dispatch execute
(
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
index 783e70e64..3eb09f917 100644
--- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -25,7 +25,7 @@ def _parse_config(self):
self.is_moe = (
self.network_config_["n_routed_experts"] is not None
and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
- and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
+ and self.layer_num_ % self.network_config_.get("moe_layer_freq", 1) == 0
)
self.tp_q_head_num_ = self.network_config_["num_attention_heads"]
self.tp_q_head_num_ = self.tp_q_head_num_ // self.tp_world_size_
@@ -65,7 +65,9 @@ def _init_weight(self):
self._init_norm()
def _split_kv_b_proj(self, kv_b_proj_):
- kv_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)
+ kv_b_proj_ = kv_b_proj_.view(
+ self.num_attention_heads, self.qk_nope_head_dim + self.v_head_dim, self.kv_lora_rank
+ )
k_b_proj_, v_b_proj_ = torch.split(kv_b_proj_, [self.qk_nope_head_dim, self.v_head_dim], dim=-2)
# num_attention_heads x qk_nope_head_dim x kv_lora_rank
k_b_proj_ = k_b_proj_.contiguous().to(kv_b_proj_.dtype)
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index f0739a8a8..e596eed97 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -43,6 +43,7 @@ def _init_some_value(self):
self.qk_rope_head_dim = self.config["qk_rope_head_dim"]
self.q_lora_rank = self.config["q_lora_rank"]
self.kv_lora_rank = self.config["kv_lora_rank"]
+ self.v_head_dim = self.config.get("v_head_dim", self.qk_nope_head_dim)
self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
def _init_custom(self):
diff --git a/lightllm/models/glm4_moe_lite/__init__.py b/lightllm/models/glm4_moe_lite/__init__.py
new file mode 100644
index 000000000..b00657090
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/__init__.py
@@ -0,0 +1,4 @@
+from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
+from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo
+
+__all__ = ["Glm4MoeLiteTpPartModel", "Glm4MoeLiteInferStateInfo"]
diff --git a/lightllm/models/glm4_moe_lite/infer_struct.py b/lightllm/models/glm4_moe_lite/infer_struct.py
new file mode 100644
index 000000000..38d879d84
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/infer_struct.py
@@ -0,0 +1,6 @@
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+
+
+class Glm4MoeLiteInferStateInfo(Deepseek2InferStateInfo):
+ def __init__(self):
+ super().__init__()
diff --git a/lightllm/models/glm4_moe_lite/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py
new file mode 100644
index 000000000..a95580535
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py
@@ -0,0 +1,3 @@
+from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer
+
+__all__ = ["Glm4MoeLiteTransformerLayerInfer"]
diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..6d0041014
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,17 @@
+from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
+
+
+class Glm4MoeLiteTransformerLayerInfer(Deepseek2TransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ self._glm4_layer_num = layer_num
+ self._glm4_first_k_dense = network_config.get("first_k_dense_replace", 0)
+ self._glm4_has_routed_experts = network_config.get("n_routed_experts") is not None
+ super().__init__(layer_num, network_config)
+
+ @property
+ def is_moe(self):
+ return self._glm4_has_routed_experts and self._glm4_layer_num >= self._glm4_first_k_dense
+
+ @is_moe.setter
+ def is_moe(self, value):
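+        # Deepseek2TransformerLayerInfer.__init__ assigns self.is_moe directly; swallow
+        # that write so the read-only property above stays authoritative for GLM-4.7-Flash.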
+ pass
diff --git a/lightllm/models/glm4_moe_lite/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py
new file mode 100644
index 000000000..1fd5e36f8
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py
@@ -0,0 +1,3 @@
+from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight
+
+__all__ = ["Glm4MoeLiteTransformerLayerWeight"]
diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py
new file mode 100644
index 000000000..92a5eabe0
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,55 @@
+import torch
+from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight
+
+
+class Glm4MoeLiteTransformerLayerWeight(Deepseek2TransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+
+ def _parse_config(self):
+ super()._parse_config()
+
+ self.is_moe = self.network_config_.get(
+ "n_routed_experts"
+ ) is not None and self.layer_num_ >= self.network_config_.get("first_k_dense_replace", 0)
+
+ from lightllm.utils.envs_utils import get_env_start_args
+
+ self.num_fused_shared_experts = 0
+ if get_env_start_args().enable_fused_shared_experts and self.is_moe:
+ assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode."
+ self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0)
+
+ def _init_moe(self):
+ moe_intermediate_size = self.network_config_["moe_intermediate_size"]
+ hidden_size = self.network_config_["hidden_size"]
+
+ self.moe_gate = ROWMMWeight(
+ in_dim=hidden_size,
+ out_dims=[self.n_routed_experts],
+ weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight",
+ data_type=torch.float32, # Router gate needs float32 for numerical stability
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+
+ if self.num_fused_shared_experts == 0:
+ self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True)
+
+ self.experts = FusedMoeWeight(
+ gate_proj_name="gate_proj",
+ down_proj_name="down_proj",
+ up_proj_name="up_proj",
+ e_score_correction_bias_name=self.e_score_correction_bias_name,
+ weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts",
+ n_routed_experts=self.n_routed_experts,
+ hidden_size=hidden_size,
+ moe_intermediate_size=moe_intermediate_size,
+ data_type=self.data_type_,
+ quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"),
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ layer_num=self.layer_num_,
+ network_config=self.network_config_,
+ )
diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py
new file mode 100644
index 000000000..367928f2e
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite/model.py
@@ -0,0 +1,54 @@
+import torch
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.deepseek2.model import Deepseek2TpPartModel
+from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer
+from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight
+from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo
+from lightllm.distributed.communication_op import dist_group_manager
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+@ModelRegistry("glm4_moe_lite")
+class Glm4MoeLiteTpPartModel(Deepseek2TpPartModel):
+
+ transformer_weight_class = Glm4MoeLiteTransformerLayerWeight
+ transformer_layer_infer_class = Glm4MoeLiteTransformerLayerInfer
+ infer_state_class = Glm4MoeLiteInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+
+ def _init_config(self):
+ super()._init_config()
+ if "scoring_func" not in self.config:
+ self.config["scoring_func"] = "sigmoid"
+
+ def _init_custom(self):
+ self._init_to_get_yarn_rotary()
+ dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"])
+
+ def _init_to_get_yarn_rotary(self):
+ rope_scaling = self.config.get("rope_scaling")
+
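+        # Fall back to plain RoPE when the config defines no rope_scaling section;
+        # otherwise reuse the Deepseek2 YaRN initialization from the parent class.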
+ if rope_scaling is None:
+ self._init_glm4_standard_rotary()
+ else:
+ super()._init_to_get_yarn_rotary()
+
+ def _init_glm4_standard_rotary(self):
+ rope_theta = self.config.get("rope_theta", 1000000.0)
+ qk_rope_head_dim = self.config.get("qk_rope_head_dim", 64)
+ max_position_embeddings = self.config.get("max_position_embeddings", 202752)
+
+ dim = qk_rope_head_dim
+
+ inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.float32) / dim))
+
+ max_seq_len = max(max_position_embeddings, self.max_seq_length)
+ t = torch.arange(max_seq_len, device="cpu", dtype=torch.float32)
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
diff --git a/lightllm/models/glm4_moe_lite_mtp/__init__.py b/lightllm/models/glm4_moe_lite_mtp/__init__.py
new file mode 100644
index 000000000..96b6659c8
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/__init__.py
@@ -0,0 +1,3 @@
+from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel
+
+__all__ = ["Glm4MoeLiteMTPModel"]
diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py
new file mode 100644
index 000000000..e357bfa19
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py
@@ -0,0 +1,3 @@
+from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer
+
+__all__ = ["Glm4MoeLiteMTPPreLayerInfer"]
diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py
new file mode 100644
index 000000000..1f2dc71d5
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py
@@ -0,0 +1,82 @@
+import torch
+
+from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import (
+ Glm4MoeLiteMTPPreAndPostLayerWeight,
+)
+from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo
+from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
+
+
+class Glm4MoeLiteMTPPreLayerInfer(LlamaPreLayerInfer):
+ def __init__(self, network_config):
+ super().__init__(network_config)
+ self.eps_ = network_config["rms_norm_eps"]
+ self.hidden_size = network_config["hidden_size"]
+
+ def _mtp_context_forward(
+ self,
+ input_embdings,
+ infer_state: Glm4MoeLiteInferStateInfo,
+ layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight,
+ ):
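+        # RMS-normalize the token embedding (enorm) and the previous draft hidden
+        # state (hnorm) separately, concatenate them, then project back to hidden
+        # size through eh_proj.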
+ tgt_embdings = infer_state.mtp_draft_input_hiddens
+ assert (
+ input_embdings.shape[0] == tgt_embdings.shape[0]
+ ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}"
+
+ layer_weight.enorm_weight_(
+ input=input_embdings,
+ eps=self.eps_,
+ out=input_embdings,
+ )
+ layer_weight.hnorm_weight_(
+ input=tgt_embdings,
+ eps=self.eps_,
+ out=tgt_embdings,
+ )
+ cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+ ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings)
+ return ans_logics
+
+ def _mtp_token_forward(
+ self,
+ input_embdings,
+ infer_state: Glm4MoeLiteInferStateInfo,
+ layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight,
+ ):
+ tgt_embdings = infer_state.mtp_draft_input_hiddens
+ assert input_embdings.shape[0] == tgt_embdings.shape[0]
+
+ layer_weight.enorm_weight_(
+ input=input_embdings,
+ eps=self.eps_,
+ out=input_embdings,
+ )
+ layer_weight.hnorm_weight_(
+ input=tgt_embdings,
+ eps=self.eps_,
+ out=tgt_embdings,
+ )
+ cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+ ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings)
+ return ans_logics
+
+ def context_forward(
+ self,
+ input_ids,
+ infer_state: Glm4MoeLiteInferStateInfo,
+ layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight,
+ ):
+ input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
+ return self._mtp_context_forward(input_embdings, infer_state, layer_weight)
+
+ def token_forward(
+ self,
+ input_ids,
+ infer_state: Glm4MoeLiteInferStateInfo,
+ layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight,
+ ):
+ input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
+ return self._mtp_token_forward(input_embdings, infer_state, layer_weight)
diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py
new file mode 100644
index 000000000..57fe578cf
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py
@@ -0,0 +1,5 @@
+from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import (
+ Glm4MoeLiteMTPPreAndPostLayerWeight,
+)
+
+__all__ = ["Glm4MoeLiteMTPPreAndPostLayerWeight"]
diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..78b4d411c
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,48 @@
+from lightllm.common.basemodel import PreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ EmbeddingWeight,
+ LMHeadWeight,
+ RMSNormWeight,
+ ROWMMWeight,
+)
+from lightllm.common.quantization import Quantcfg
+
+
+class Glm4MoeLiteMTPPreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, quant_cfg: Quantcfg):
+ super().__init__(data_type, network_config)
+ self.quant_cfg: Quantcfg = quant_cfg
+
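+        # The MTP head's weights are stored after the last decoder layer, so their
+        # layer index in the checkpoint equals num_hidden_layers.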
+ mtp_layer_idx = network_config["num_hidden_layers"]
+ hidden_size = network_config["hidden_size"]
+
+ self.eh_proj_weight_ = ROWMMWeight(
+ in_dim=hidden_size * 2,
+ out_dims=[hidden_size],
+ weight_names=f"model.layers.{mtp_layer_idx}.eh_proj.weight",
+ data_type=self.data_type_,
+ quant_method=self.quant_cfg.get_quant_method(mtp_layer_idx, "eh_proj"),
+ tp_rank=0,
+ tp_world_size=1,
+ )
+
+ self.enorm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=f"model.layers.{mtp_layer_idx}.enorm.weight",
+ data_type=self.data_type_,
+ )
+
+ self.hnorm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=f"model.layers.{mtp_layer_idx}.hnorm.weight",
+ data_type=self.data_type_,
+ )
+
+ self.final_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=f"model.layers.{mtp_layer_idx}.shared_head.norm.weight",
+ data_type=self.data_type_,
+ )
+
+ self.wte_weight_: EmbeddingWeight = None
+ self.lm_head_weight_: LMHeadWeight = None
diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py
new file mode 100644
index 000000000..dea415abb
--- /dev/null
+++ b/lightllm/models/glm4_moe_lite_mtp/model.py
@@ -0,0 +1,90 @@
+from typing import List
+from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel
+from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer
+from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import (
+ Glm4MoeLiteMTPPreAndPostLayerWeight,
+)
+from lightllm.common.basemodel import TpPartBaseModel
+from lightllm.common.basemodel.basemodel import load_hf_weights
+
+
+class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel):
+
+ pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight
+ pre_layer_infer_class = Glm4MoeLiteMTPPreLayerInfer
+
+ def __init__(self, kvargs: dict):
+ self._pre_init(kvargs)
+ super().__init__(kvargs)
+
+ def _pre_init(self, kvargs: dict):
+ self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+ self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models")
+
+ def _init_custom(self):
+ self._cos_cached = self.main_model._cos_cached
+ self._sin_cached = self.main_model._sin_cached
+
+ def _init_req_manager(self):
+ self.req_manager = self.main_model.req_manager
+
+ def _init_mem_manager(self):
+ self.mem_manager = self.main_model.mem_manager
+
+ def _init_weights(self, start_layer_index=None):
+ assert start_layer_index is None
+
+ mtp_layer_start = self.config["num_hidden_layers"]
+ num_mtp_layers = self.config.get("num_nextn_predict_layers", 1)
+
+ self.pre_post_weight = self.pre_and_post_weight_class(
+ self.data_type, network_config=self.config, quant_cfg=self.quant_cfg
+ )
+
+ self.trans_layers_weight = [
+ self.transformer_weight_class(
+ i,
+ self.data_type,
+ network_config=self.config,
+ quant_cfg=self.quant_cfg,
+ )
+ for i in range(mtp_layer_start, mtp_layer_start + num_mtp_layers)
+ ]
+
+ load_hf_weights(
+ self.data_type,
+ weight_dir=self.weight_dir_,
+ pre_post_layer=self.pre_post_weight,
+ transformer_layer_list=self.trans_layers_weight,
+ weight_dict=self.weight_dict,
+ )
+
+ self.pre_post_weight.verify_load()
+ [weight.verify_load() for weight in self.trans_layers_weight]
+
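+        # Share the embedding and lm_head tensors with the main model instead of
+        # loading separate copies for the draft model.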
+ self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+ self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+
+ def _init_infer_layer(self, start_layer_index=None):
+ assert start_layer_index is None
+
+ self.pre_infer = self.pre_layer_infer_class(network_config=self.config)
+ self.post_infer = self.post_layer_infer_class(network_config=self.config)
+
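+        # Draft transformer layers are numbered after the main model's layers and any
+        # previously stacked draft models.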
+ total_pre_layers_num = len(self.main_model.layers_infer)
+ total_pre_layers_num += sum(
+ [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models]
+ )
+
+ num_mtp_layers = self.config.get("num_nextn_predict_layers", 1)
+ self.layers_infer = [
+ self.transformer_layer_infer_class(i, network_config=self.config)
+ for i in range(total_pre_layers_num, total_pre_layers_num + num_mtp_layers)
+ ]
+
+ def _init_some_value(self):
+ super()._init_some_value()
+ self.layers_num = self.config.get("num_nextn_predict_layers", 1)
+
+ def autotune_layers(self):
+ return self.config.get("num_nextn_predict_layers", 1)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 3cf04a14c..73b9bad4a 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--tool_call_parser",
type=str,
- choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31"],
+ choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"],
default=None,
help="tool call parser type",
)
@@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache")
- parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size")
+ parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index 1620cff13..9214715b1 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -241,7 +241,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
if start_idx >= len(current_text):
return StreamingParseResult()
- (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
+ obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
is_current_complete = _is_complete_json(current_text[start_idx : start_idx + end_idx])
@@ -1173,6 +1173,276 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
return StreamingParseResult(normal_text=current_text)
+class Glm47Detector(BaseFormatDetector):
+ """
+ Detector for GLM-4.7/GLM-4.7-Flash model function call format.
+
+ The GLM-4.7 format uses an XML-style envelope with arg_key/arg_value pairs
+ instead of JSON arguments.
+
+    Format Structure:
+    ```
+    <tool_call>function_name
+    <arg_key>param1</arg_key>
+    <arg_value>value1</arg_value>
+    <arg_key>param2</arg_key>
+    <arg_value>value2</arg_value>
+    </tool_call>
+    ```
+
+    Example:
+    ```
+    <tool_call>tool_brave_web_search_post
+    <arg_key>query</arg_key>
+    <arg_value>test search</arg_value>
+    <arg_key>count</arg_key>
+    <arg_value>5</arg_value>
+    </tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Name: Appears on the first line after `<tool_call>`
+    - Arguments: Pairs of `<arg_key>name</arg_key>` and `<arg_value>value</arg_value>`
+
+ Reference: https://github.com/vllm-project/vllm/blob/main/vllm/tool_parsers/glm4_moe_tool_parser.py
+ """
+
+ def __init__(self):
+ super().__init__()
+        self.bot_token = "<tool_call>"
+        self.eot_token = "</tool_call>"
+ self.tool_call_separator = "\n"
+
+ # Regex patterns for parsing GLM-4.7 tool calls
+ # Match complete tool call blocks
+        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
+ # Extract function name and arguments from a tool call block
+        # Function name can be followed by newline OR directly by <arg_key>
+        # Pattern: <tool_call>function_name(\n|<arg_key>)...</tool_call>
+        self.func_detail_regex = re.compile(
+            r"<tool_call>([^<\n]+?)(?:\n|(?=<arg_key>)|(?=</tool_call>))(.*?)</tool_call>", re.DOTALL
+        )
+        # Extract arg_key/arg_value pairs
+        self.func_arg_regex = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL)
+
+ self._last_arguments = ""
+ self._normal_text_buffer = ""
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a GLM-4.7 format tool call."""
+ return self.bot_token in text
+
+ def _parse_xml_arguments(self, arg_text: str) -> dict:
+ """
+ Parse XML-style arguments into a dictionary.
+
+ Args:
+            arg_text: The text containing <arg_key>/<arg_value> pairs
+
+ Returns:
+ Dictionary of argument name to value
+ """
+ if not arg_text:
+ return {}
+
+ args = {}
+ matches = self.func_arg_regex.findall(arg_text)
+ for key, value in matches:
+ key = key.strip()
+ value = value.strip()
+ # Try to parse value as JSON for complex types (arrays, objects, numbers, booleans)
+ try:
+ parsed_value = json.loads(value)
+ args[key] = parsed_value
+ except (json.JSONDecodeError, ValueError):
+ # Keep as string if not valid JSON
+ args[key] = value
+ return args
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: StreamingParseResult with normal_text and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
+ tool_indices = self._get_tool_indices(tools)
+ calls = []
+
+        # Find all <tool_call>...</tool_call> blocks
+ match_result_list = self.func_call_regex.findall(text)
+
+ for match_result in match_result_list:
+ try:
+ # Extract function name and arguments
+ func_detail = self.func_detail_regex.search(match_result)
+ if not func_detail:
+ logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}")
+ continue
+
+ func_name = func_detail.group(1).strip()
+ arg_text = func_detail.group(2) if func_detail.group(2) else ""
+
+ # Validate function name
+ if func_name not in tool_indices:
+ logger.warning(f"Model attempted to call undefined function: {func_name}")
+ continue
+
+ # Parse XML arguments to JSON
+ func_args = self._parse_xml_arguments(arg_text)
+
+ calls.append(
+ ToolCallItem(
+ tool_index=tool_indices[func_name],
+ name=func_name,
+ parameters=json.dumps(func_args, ensure_ascii=False),
+ )
+ )
+ except Exception as e:
+ logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}, error: {str(e)}")
+ continue
+
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for GLM-4.7 tool calls.
+
+ This handles the streaming case where tool calls arrive incrementally.
+ """
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # Check if we have a tool call starting
+ if not self.has_tool_call(current_text):
+ # Check for partial bot_token at the end
+ partial_len = self._ends_with_partial_token(current_text, self.bot_token)
+ if partial_len:
+ # Might be partial bot_token, keep buffering
+ return StreamingParseResult()
+
+ # No tool call, emit as normal text
+ self._buffer = ""
+ # Clean up any stray end tokens
+ if self.eot_token in new_text:
+ new_text = new_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ # Build tool indices if not already built
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: List[ToolCallItem] = []
+
+ try:
+ # Check if we have a complete tool call
+ if self.eot_token in current_text:
+ # We have at least one complete tool call
+ # Parse all complete tool calls
+ result = self.detect_and_parse(current_text, tools)
+
+ # Find the end of the last complete tool call
+ last_end = current_text.rfind(self.eot_token)
+ if last_end != -1:
+ remaining = current_text[last_end + len(self.eot_token) :]
+ self._buffer = remaining.lstrip()
+ else:
+ self._buffer = ""
+
+ # Reset state for next tool call
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ self._last_arguments = ""
+
+ return result
+
+ # We have a partial tool call - try to stream it
+ # Extract what we can from the partial tool call
+ tool_call_start = current_text.find(self.bot_token)
+ if tool_call_start == -1:
+ return StreamingParseResult()
+
+        # Get content after <tool_call>
+        content_after_start = current_text[tool_call_start + len(self.bot_token) :]
+
+        # Try to extract function name (first line after <tool_call>)
+ newline_pos = content_after_start.find("\n")
+ if newline_pos == -1:
+ # Still waiting for function name to complete
+ return StreamingParseResult()
+
+ func_name = content_after_start[:newline_pos].strip()
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ # Ensure we have enough entries
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Check if function name is valid
+ if func_name and func_name in self._tool_indices:
+ if not self.current_tool_name_sent:
+ # Send function name first
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+ else:
+ # Stream arguments incrementally
+ arg_text = content_after_start[newline_pos + 1 :]
+ current_args = self._parse_xml_arguments(arg_text)
+
+ if current_args:
+ current_args_json = json.dumps(current_args, ensure_ascii=False)
+ prev_args = self.prev_tool_call_arr[self.current_tool_id].get("arguments", {})
+ prev_args_json = json.dumps(prev_args, ensure_ascii=False) if prev_args else ""
+
+ if current_args_json != prev_args_json:
+ # Calculate the diff
+ sent = len(self.streamed_args_for_tool[self.current_tool_id])
+ argument_diff = current_args_json[sent:]
+
+ if argument_diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=argument_diff,
+ )
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += argument_diff
+
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in GLM-4.7 parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text="", calls=calls)
+
+
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
@@ -1185,6 +1455,7 @@ class FunctionCallParser:
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
"deepseekv3": DeepSeekV3Detector,
"deepseekv31": DeepSeekV31Detector,
+ "glm47": Glm47Detector,
"kimi_k2": KimiK2Detector,
"llama3": Llama32Detector,
"mistral": MistralDetector,
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 805c9b8e5..64310d6b0 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -37,6 +37,7 @@
from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel
from lightllm.models.mistral_mtp.model import MistralMTPModel
+from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel
from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
@@ -328,6 +329,9 @@ def init_mtp_draft_model(self, main_kvargs: dict):
elif mtp_model_cfg["model_type"] == "mistral":
assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
self.draft_models.append(MistralMTPModel(mtp_model_kvargs))
+ elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
+ assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
+ self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
else:
assert False, f"error mtp mode {mtp_model_cfg['model_type']}"
diff --git a/test/acc/bfcl/eval_bfcl.py b/test/acc/bfcl/eval_bfcl.py
new file mode 100644
index 000000000..59b81c26d
--- /dev/null
+++ b/test/acc/bfcl/eval_bfcl.py
@@ -0,0 +1,686 @@
+#!/usr/bin/env python3
+"""
+BFCL (Berkeley Function Calling Leaderboard) Evaluation Script for LightLLM
+
+This script evaluates function/tool calling capabilities on the BFCL benchmark.
+
+Usage:
+ # Start LightLLM server first:
+ python -m lightllm.server.api_server --model_dir /path/to/GLM-4.7-Flash --tp 1
+
+ # Run evaluation:
+ python eval_bfcl.py \
+ --model_name GLM-4.7-Flash \
+ --base_url http://localhost:8000/v1 \
+ --test_category simple
+
+Test Categories:
+ - simple: Single function calls (400 examples)
+ - multiple: Select one function from multiple options (200 examples)
+ - parallel: Multiple function calls in parallel (200 examples)
+ - parallel_multiple: Combination of parallel and multiple (200 examples)
+ - java: Java function calls (100 examples)
+ - javascript: JavaScript function calls (70 examples)
+ - irrelevance: Detect when no function should be called
+ - all: Run all categories
+
+Requirements:
+ pip install openai tqdm huggingface_hub
+"""
+
+import argparse
+import json
+import os
+import re
+import ast
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import List, Dict, Any, Optional, Tuple
+from dataclasses import dataclass
+from collections import defaultdict
+
+from tqdm import tqdm
+
+try:
+ from openai import OpenAI
+except ImportError:
+ print("Please install openai: pip install openai")
+ exit(1)
+
+try:
+ from huggingface_hub import hf_hub_download
+except ImportError:
+ print("Please install huggingface_hub: pip install huggingface_hub")
+ exit(1)
+
+
+# BFCL Dataset on HuggingFace
+BFCL_REPO = "gorilla-llm/Berkeley-Function-Calling-Leaderboard"
+
+# Test category mappings to filenames
+TEST_CATEGORIES = {
+ "simple": "BFCL_v3_simple.json",
+ "multiple": "BFCL_v3_multiple.json",
+ "parallel": "BFCL_v3_parallel.json",
+ "parallel_multiple": "BFCL_v3_parallel_multiple.json",
+ "java": "BFCL_v3_java.json",
+ "javascript": "BFCL_v3_javascript.json",
+ "irrelevance": "BFCL_v3_irrelevance.json",
+ "live_simple": "BFCL_v3_live_simple.json",
+ "live_multiple": "BFCL_v3_live_multiple.json",
+ "live_parallel": "BFCL_v3_live_parallel.json",
+ "live_parallel_multiple": "BFCL_v3_live_parallel_multiple.json",
+ "rest": "BFCL_v3_rest.json",
+ "sql": "BFCL_v3_sql.json",
+}
+
+# Possible answer files for ground truth
+ANSWER_FILES = {
+ "simple": "possible_answer/BFCL_v3_simple.json",
+ "multiple": "possible_answer/BFCL_v3_multiple.json",
+ "parallel": "possible_answer/BFCL_v3_parallel.json",
+ "parallel_multiple": "possible_answer/BFCL_v3_parallel_multiple.json",
+ "java": "possible_answer/BFCL_v3_java.json",
+ "javascript": "possible_answer/BFCL_v3_javascript.json",
+ "live_simple": "possible_answer/BFCL_v3_live_simple.json",
+ "live_multiple": "possible_answer/BFCL_v3_live_multiple.json",
+ "live_parallel": "possible_answer/BFCL_v3_live_parallel.json",
+ "live_parallel_multiple": "possible_answer/BFCL_v3_live_parallel_multiple.json",
+ "sql": "possible_answer/BFCL_v3_sql.json",
+}
+
+
+@dataclass
+class EvalResult:
+ """Result of a single evaluation."""
+
+ task_id: str
+ category: str
+ passed: bool
+ model_output: str
+ expected: Any
+ error: Optional[str] = None
+
+
+def download_bfcl_file(filename: str) -> str:
+ """Download a BFCL file from HuggingFace Hub."""
+ try:
+ local_path = hf_hub_download(
+ repo_id=BFCL_REPO,
+ filename=filename,
+ repo_type="dataset",
+ )
+ return local_path
+ except Exception as e:
+ print(f"Error downloading {filename}: {e}")
+ return None
+
+
+def load_jsonl_or_json(filepath: str) -> List[Dict[str, Any]]:
+ """Load data from JSON or JSONL file."""
+ data = []
+ with open(filepath, "r", encoding="utf-8") as f:
+ content = f.read().strip()
+ # Try as JSON array first
+ try:
+ data = json.loads(content)
+ if isinstance(data, dict):
+ data = [data]
+ except json.JSONDecodeError:
+ # Try as JSONL
+ f.seek(0)
+ for line in f:
+ line = line.strip()
+ if line:
+ try:
+ data.append(json.loads(line))
+ except json.JSONDecodeError:
+ continue
+ return data
+
+
+def load_bfcl_data(category: str, limit: Optional[int] = None) -> List[Dict[str, Any]]:
+ """Load BFCL dataset for a specific category."""
+ filename = TEST_CATEGORIES.get(category)
+ if not filename:
+ print(f"Unknown category: {category}")
+ return []
+
+ print(f"Downloading {filename} from HuggingFace...")
+ filepath = download_bfcl_file(filename)
+ if not filepath:
+ return []
+
+ print(f"Loading data from {filepath}")
+ data = load_jsonl_or_json(filepath)
+
+ # Also load ground truth answers if available
+ answer_file = ANSWER_FILES.get(category)
+ if answer_file:
+ print(f"Downloading answer file {answer_file}...")
+ answer_path = download_bfcl_file(answer_file)
+ if answer_path:
+ answers = load_jsonl_or_json(answer_path)
+ # Create a mapping from id to answer
+ answer_map = {}
+ for ans in answers:
+ ans_id = ans.get("id", "")
+ answer_map[ans_id] = ans.get("ground_truth", ans.get("result", []))
+
+ # Merge answers into data
+ for item in data:
+ item_id = item.get("id", "")
+ if item_id in answer_map:
+ item["ground_truth"] = answer_map[item_id]
+
+ if limit:
+ data = data[:limit]
+
+ print(f"Loaded {len(data)} examples for category: {category}")
+ return data
+
+
+def fix_schema_types(schema: Any) -> Any:
+ """
+ Fix Python type names to JSON Schema types.
+ BFCL uses Python type names like 'dict', 'list' but JSON Schema needs 'object', 'array'.
+ """
+ if isinstance(schema, dict):
+ result = {}
+ for key, value in schema.items():
+ if key == "type" and isinstance(value, str):
+ # Map Python types to JSON Schema types
+ type_mapping = {
+ "dict": "object",
+ "list": "array",
+ "str": "string",
+ "int": "integer",
+ "float": "number",
+ "bool": "boolean",
+ "NoneType": "null",
+ "tuple": "array",
+ }
+ result[key] = type_mapping.get(value, value)
+ else:
+ result[key] = fix_schema_types(value)
+ return result
+ elif isinstance(schema, list):
+ return [fix_schema_types(item) for item in schema]
+ else:
+ return schema
+
+
+def convert_to_openai_tools(functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Convert BFCL function format to OpenAI tools format."""
+ tools = []
+ for func in functions:
+ if isinstance(func, str):
+ func = json.loads(func)
+
+ # Fix the parameters schema to use valid JSON Schema types
+ parameters = fix_schema_types(func.get("parameters", {}))
+
+ tool = {
+ "type": "function",
+ "function": {
+ "name": func.get("name", ""),
+ "description": func.get("description", ""),
+ "parameters": parameters,
+ },
+ }
+ tools.append(tool)
+ return tools
+
+
+def parse_function_call(response: str) -> List[Dict[str, Any]]:
+ """Parse function calls from model response."""
+ calls = []
+
+ # Try to parse as JSON array
+ try:
+ parsed = json.loads(response)
+ if isinstance(parsed, list):
+ return parsed
+ elif isinstance(parsed, dict):
+ return [parsed]
+ except json.JSONDecodeError:
+ pass
+
+ # Try to find function call patterns
+ # Pattern 1: function_name(args)
+ func_pattern = r"(\w+)\s*\((.*?)\)"
+ matches = re.findall(func_pattern, response, re.DOTALL)
+    for name, args_str in matches:
+        try:
+            # Arguments appear as Python keyword arguments, e.g. foo(city="Boston", days=3)
+            args_str = args_str.strip()
+            if args_str:
+                # eval is a lenient fallback for this offline benchmark script only
+                args = eval(f"dict({args_str})")
+            else:
+                args = {}
+            calls.append({"name": name, "arguments": args})
+        except Exception:
+            pass
+
+ # Pattern 2: JSON-like tool_calls
+ tool_call_pattern = r'\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]+\})\}'
+ matches = re.findall(tool_call_pattern, response)
+ for name, args_str in matches:
+        try:
+            args = json.loads(args_str)
+            calls.append({"name": name, "arguments": args})
+        except json.JSONDecodeError:
+            pass
+
+ return calls
+
+
+def extract_tool_calls_from_response(response) -> List[Dict[str, Any]]:
+ """Extract tool calls from OpenAI API response."""
+ calls = []
+
+ if hasattr(response, "choices") and response.choices:
+ choice = response.choices[0]
+ message = choice.message
+
+ # Check for tool_calls in response
+ if hasattr(message, "tool_calls") and message.tool_calls:
+ for tool_call in message.tool_calls:
+ func = tool_call.function
+ try:
+ args = json.loads(func.arguments) if func.arguments else {}
+ except json.JSONDecodeError:
+ args = {}
+ calls.append({"name": func.name, "arguments": args})
+
+ # Also check content for function calls (some models output in content)
+ if hasattr(message, "content") and message.content:
+ content_calls = parse_function_call(message.content)
+ if content_calls and not calls:
+ calls = content_calls
+
+ return calls
+
+
+def normalize_value(value: Any) -> Any:
+ """Normalize values for comparison."""
+ if isinstance(value, str):
+ # Try to parse as number
+ try:
+ return float(value)
+ except ValueError:
+ return value.lower().strip()
+ elif isinstance(value, bool):
+ return value
+ elif isinstance(value, (int, float)):
+ return float(value)
+ elif isinstance(value, list):
+ return [normalize_value(v) for v in value]
+ elif isinstance(value, dict):
+ return {k: normalize_value(v) for k, v in value.items()}
+ return value
+
+
+def value_matches_expected(predicted_value: Any, expected_values: Any) -> bool:
+ """
+ Check if predicted value matches expected value(s).
+ BFCL format: expected values can be a list of acceptable values.
+ """
+ # Normalize predicted value
+ pred_normalized = normalize_value(predicted_value)
+
+ # If expected is a list, check if predicted matches any item
+ if isinstance(expected_values, list):
+ for exp_val in expected_values:
+ exp_normalized = normalize_value(exp_val)
+ if pred_normalized == exp_normalized:
+ return True
+ # Also try string comparison for edge cases
+ if str(pred_normalized) == str(exp_normalized):
+ return True
+ return False
+ else:
+ exp_normalized = normalize_value(expected_values)
+ return pred_normalized == exp_normalized or str(pred_normalized) == str(exp_normalized)
+
+
+def compare_function_calls(
+ predicted: List[Dict[str, Any]], expected: List[Dict[str, Any]], strict: bool = False
+) -> Tuple[bool, str]:
+ """Compare predicted function calls with expected ones."""
+ if not predicted and not expected:
+ return True, ""
+
+ if len(predicted) != len(expected):
+ return False, f"Count mismatch: predicted {len(predicted)}, expected {len(expected)}"
+
+ # Sort by function name for comparison
+ pred_sorted = sorted(predicted, key=lambda x: x.get("name", ""))
+ exp_sorted = sorted(expected, key=lambda x: x.get("name", ""))
+
+ for pred, exp in zip(pred_sorted, exp_sorted):
+ pred_name = pred.get("name", "")
+ exp_name = exp.get("name", "")
+
+ if pred_name != exp_name:
+ return False, f"Function name mismatch: {pred_name} vs {exp_name}"
+
+ pred_args = pred.get("arguments", {})
+ exp_args = exp.get("arguments", {})
+
+ # Check required arguments match (BFCL format: values are lists of acceptable values)
+ for key, expected_values in exp_args.items():
+ if key not in pred_args:
+ return False, f"Missing argument {key} in {pred_name}"
+ if not value_matches_expected(pred_args[key], expected_values):
+ return False, f"Argument {key} mismatch in {pred_name}"
+
+ return True, ""
+
+
+def parse_expected_output(ground_truth: Any) -> List[Dict[str, Any]]:
+ """
+ Parse expected output from BFCL ground truth.
+
+ BFCL format: [{"func_name": {"arg1": [val1, val2], "arg2": [val3]}}]
+ Convert to: [{"name": "func_name", "arguments": {"arg1": [val1, val2], "arg2": [val3]}}]
+ """
+ if isinstance(ground_truth, str):
+ try:
+ ground_truth = json.loads(ground_truth)
+ except json.JSONDecodeError:
+ # Try parsing as Python literal
+ try:
+ ground_truth = ast.literal_eval(ground_truth)
+            except (ValueError, SyntaxError):
+ return []
+
+ if not ground_truth:
+ return []
+
+ # Ensure it's a list
+ if isinstance(ground_truth, dict):
+ ground_truth = [ground_truth]
+
+ result = []
+ for item in ground_truth:
+ if isinstance(item, dict):
+ # Check if it's already in standard format {"name": ..., "arguments": ...}
+ if "name" in item and "arguments" in item:
+ result.append(item)
+ else:
+ # BFCL format: {"func_name": {"arg1": [v1], "arg2": [v2]}}
+ for func_name, args in item.items():
+ if isinstance(args, dict):
+ result.append({"name": func_name, "arguments": args})
+ else:
+ # Handle edge case where args might not be a dict
+ result.append({"name": func_name, "arguments": {}})
+
+ return result
+
+
+class BFCLEvaluator:
+ """BFCL Benchmark Evaluator using OpenAI-compatible API."""
+
+ def __init__(
+ self,
+ base_url: str,
+ model_name: str,
+ api_key: str = "EMPTY",
+ max_tokens: int = 1024,
+ temperature: float = 0.0,
+ ):
+ self.client = OpenAI(base_url=base_url, api_key=api_key)
+ self.model_name = model_name
+ self.max_tokens = max_tokens
+ self.temperature = temperature
+
+ def generate_response(
+ self, prompt: str, tools: List[Dict[str, Any]], system_prompt: Optional[str] = None
+ ) -> Tuple[Any, List[Dict[str, Any]]]:
+ """Generate response from the model with tool calling."""
+ messages = []
+
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ messages.append({"role": "user", "content": prompt})
+
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=messages,
+ tools=tools if tools else None,
+ tool_choice="auto" if tools else None,
+ max_tokens=self.max_tokens,
+ temperature=self.temperature,
+ )
+ tool_calls = extract_tool_calls_from_response(response)
+ return response, tool_calls
+ except Exception as e:
+ print(f"API Error: {e}")
+ return None, []
+
+ def evaluate_single(self, item: Dict[str, Any], category: str) -> EvalResult:
+ """Evaluate a single BFCL example."""
+ task_id = item.get("id", "unknown")
+
+ # Extract question and functions
+ question = item.get("question", [[{"role": "user", "content": ""}]])
+ if isinstance(question, str):
+ prompt = question
+ elif isinstance(question, list) and question:
+ if isinstance(question[0], dict):
+ prompt = question[0].get("content", "")
+ elif isinstance(question[0], list) and question[0]:
+ prompt = question[0][0].get("content", "")
+ else:
+ prompt = str(question[0])
+ else:
+ prompt = str(question)
+
+ # Get functions
+ functions = item.get("function", [])
+ if isinstance(functions, str):
+ try:
+ functions = json.loads(functions)
+            except json.JSONDecodeError:
+ functions = []
+
+ if not isinstance(functions, list):
+ functions = [functions]
+
+ # Convert to OpenAI tools format
+ tools = convert_to_openai_tools(functions)
+
+ # Get expected output
+ ground_truth = item.get("ground_truth", item.get("answer", []))
+ expected = parse_expected_output(ground_truth)
+
+ # Generate response
+ system_prompt = (
+ "You are a helpful assistant that can use tools/functions to help answer questions. "
+ "When you need to call a function, use the provided tools."
+ )
+
+ response, predicted_calls = self.generate_response(prompt, tools, system_prompt)
+
+ if response is None:
+ return EvalResult(
+ task_id=task_id,
+ category=category,
+ passed=False,
+ model_output="",
+ expected=expected,
+ error="API call failed",
+ )
+
+ # For irrelevance category, model should NOT call any function
+ if "irrelevance" in category.lower():
+ passed = len(predicted_calls) == 0
+ error = "Model called function when it shouldn't" if not passed else None
+ else:
+ # Compare function calls
+ passed, error = compare_function_calls(predicted_calls, expected)
+
+ model_output = json.dumps(predicted_calls, indent=2) if predicted_calls else str(response)
+
+ return EvalResult(
+ task_id=task_id, category=category, passed=passed, model_output=model_output, expected=expected, error=error
+ )
+
+ def evaluate_category(self, category: str, limit: Optional[int] = None, num_workers: int = 4) -> Dict[str, Any]:
+ """Evaluate all examples in a category."""
+ print(f"\nLoading BFCL dataset for category: {category}")
+ data = load_bfcl_data(category, limit)
+
+ if not data:
+ print(f"No data found for category: {category}")
+ return {"category": category, "total": 0, "passed": 0, "accuracy": 0.0}
+
+ print(f"Loaded {len(data)} examples")
+
+ results = []
+
+ # Use ThreadPoolExecutor for concurrent evaluation
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = {executor.submit(self.evaluate_single, item, category): item for item in data}
+
+ for future in tqdm(as_completed(futures), total=len(futures), desc=f"Evaluating {category}"):
+ try:
+ result = future.result()
+ results.append(result)
+ except Exception as e:
+ print(f"Error evaluating: {e}")
+
+ # Calculate metrics
+ total = len(results)
+ passed = sum(1 for r in results if r.passed)
+ accuracy = passed / total * 100 if total > 0 else 0.0
+
+ # Collect errors for analysis
+ errors = defaultdict(int)
+ for r in results:
+ if not r.passed and r.error:
+ errors[r.error[:50]] += 1
+
+ return {
+ "category": category,
+ "total": total,
+ "passed": passed,
+ "accuracy": accuracy,
+ "results": results,
+ "error_summary": dict(errors),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(description="BFCL Evaluation for LightLLM")
+ parser.add_argument("--model_name", type=str, required=True, help="Model name")
+ parser.add_argument(
+ "--base_url", type=str, default="http://localhost:8000/v1", help="OpenAI-compatible API base URL"
+ )
+ parser.add_argument("--api_key", type=str, default="EMPTY", help="API key (use EMPTY for local)")
+ parser.add_argument(
+ "--test_category",
+ type=str,
+ default="simple",
+ choices=list(TEST_CATEGORIES.keys()) + ["all"],
+ help="Test category to evaluate",
+ )
+ parser.add_argument("--limit", type=int, default=None, help="Limit number of examples (for testing)")
+ parser.add_argument("--num_workers", type=int, default=4, help="Number of concurrent workers")
+ parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens to generate")
+ parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
+ parser.add_argument("--output", "-o", type=str, default=None, help="Output file for detailed results")
+
+ args = parser.parse_args()
+
+ print("=" * 60)
+ print("BFCL (Berkeley Function Calling Leaderboard) Evaluation")
+ print("=" * 60)
+ print(f"Model: {args.model_name}")
+ print(f"API URL: {args.base_url}")
+ print(f"Test Category: {args.test_category}")
+ print()
+
+ evaluator = BFCLEvaluator(
+ base_url=args.base_url,
+ model_name=args.model_name,
+ api_key=args.api_key,
+ max_tokens=args.max_tokens,
+ temperature=args.temperature,
+ )
+
+ # Determine categories to evaluate
+ if args.test_category == "all":
+ categories = list(TEST_CATEGORIES.keys())
+ else:
+ categories = [args.test_category]
+
+ all_results = {}
+
+ for category in categories:
+ result = evaluator.evaluate_category(category, limit=args.limit, num_workers=args.num_workers)
+ all_results[category] = result
+
+ print(f"\n{category.upper()} Results:")
+ print(f" Total: {result['total']}")
+ print(f" Passed: {result['passed']}")
+ print(f" Accuracy: {result['accuracy']:.2f}%")
+
+ if result.get("error_summary"):
+ print(" Common errors:")
+ for error, count in sorted(result["error_summary"].items(), key=lambda x: -x[1])[:5]:
+ print(f" - {error}: {count}")
+
+ # Print summary
+ print("\n" + "=" * 60)
+ print("SUMMARY")
+ print("=" * 60)
+ print(f"{'Category':<25} {'Total':>8} {'Passed':>8} {'Accuracy':>10}")
+ print("-" * 60)
+
+ total_all = 0
+ passed_all = 0
+
+ for category, result in all_results.items():
+ print(f"{category:<25} {result['total']:>8} {result['passed']:>8} {result['accuracy']:>9.2f}%")
+ total_all += result["total"]
+ passed_all += result["passed"]
+
+ if len(all_results) > 1:
+ print("-" * 60)
+ overall_acc = passed_all / total_all * 100 if total_all > 0 else 0
+ print(f"{'OVERALL':<25} {total_all:>8} {passed_all:>8} {overall_acc:>9.2f}%")
+
+ print("=" * 60)
+
+ # Save detailed results
+ if args.output:
+ output_data = {
+ "model": args.model_name,
+ "config": {
+ "base_url": args.base_url,
+ "max_tokens": args.max_tokens,
+ "temperature": args.temperature,
+ },
+ "results": {
+ cat: {
+ "total": r["total"],
+ "passed": r["passed"],
+ "accuracy": r["accuracy"],
+ "error_summary": r.get("error_summary", {}),
+ }
+ for cat, r in all_results.items()
+ },
+ }
+ with open(args.output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ print(f"\nResults saved to {args.output}")
+
+
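+# Example invocation (paths/ports are illustrative):
+#   python eval_bfcl.py --model_name GLM-4.7-Flash --base_url http://localhost:8000/v1 \
+#       --test_category simple --limit 20 --output bfcl_simple.json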
+if __name__ == "__main__":
+ main()
diff --git a/test/acc/bfcl/requirements.txt b/test/acc/bfcl/requirements.txt
new file mode 100644
index 000000000..e57d2da41
--- /dev/null
+++ b/test/acc/bfcl/requirements.txt
@@ -0,0 +1,13 @@
+# Evaluation benchmark dependencies
+aiohttp>=3.8.0
+tqdm>=4.64.0
+transformers>=4.30.0
+numpy>=1.21.0
+openai>=1.0.0
+huggingface_hub>=0.20.0
+
+# Optional: official BFCL evaluation package
+# pip install bfcl-eval
diff --git a/test/acc/bfcl/run_bfcl.sh b/test/acc/bfcl/run_bfcl.sh
new file mode 100644
index 000000000..2e68f8380
--- /dev/null
+++ b/test/acc/bfcl/run_bfcl.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+# BFCL (Berkeley Function Calling Leaderboard) evaluation script for LightLLM
+#
+# Prerequisites:
+# 1. Start LightLLM server:
+# python -m lightllm.server.api_server \
+# --model_dir /path/to/GLM-4.7-Flash \
+# --tp 1 \
+# --port 8000
+#
+#   2. Install dependencies (see requirements.txt):
+#      pip install openai tqdm huggingface_hub
+
+set -e
+
+# Configuration
+MODEL_NAME="${MODEL_NAME:-GLM-4.7-Flash}"
+PORT="${PORT:-8000}"
+BASE_URL="${BASE_URL:-http://localhost:${PORT}/v1}"
+TEST_CATEGORY="${TEST_CATEGORY:-simple}"
+NUM_WORKERS="${NUM_WORKERS:-4}"
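+# All of the above can be overridden via the environment, e.g. (values are illustrative):
+#   TEST_CATEGORY=all NUM_WORKERS=8 bash run_bfcl.sh
+#   MODEL_NAME=my-model PORT=9000 bash run_bfcl.sh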
+
+# Check if server is running
+if ! curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then
+ echo "Error: LightLLM server not running on port ${PORT}"
+ echo "Start the server first with:"
+ echo " python -m lightllm.server.api_server --model_dir /path/to/model --tp 1 --port ${PORT}"
+ exit 1
+fi
+
+echo "=========================================="
+echo "BFCL Function Calling Evaluation"
+echo "=========================================="
+echo "Model: ${MODEL_NAME}"
+echo "Server: ${BASE_URL}"
+echo "Test Category: ${TEST_CATEGORY}"
+echo ""
+
+# Run evaluation
+python "$(dirname "$0")/eval_bfcl.py" \
+ --model_name "${MODEL_NAME}" \
+ --base_url "${BASE_URL}" \
+ --test_category "${TEST_CATEGORY}" \
+ --num_workers "${NUM_WORKERS}" \
+ --output "bfcl_results_${TEST_CATEGORY}_$(date +%Y%m%d_%H%M%S).json"
+
+echo ""
+echo "Evaluation complete!"