From 5f36055320a24b8836e505185e086ce9caf5aea7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 27 Jan 2026 15:59:55 +0000 Subject: [PATCH 01/14] feat: add GLM-4.7-Flash (glm4_moe_lite) model support - Add glm4_moe_lite model implementation with MLA attention - Add glm4_moe_lite_mtp for multi-token prediction support - Refactor attention kernels to use dynamic batch size - Add kernel configs for H200 GPU optimization - Add BFCL evaluation scripts for function calling --- .../basemodel/attention/flashinfer/mla.py | 7 +- .../common/basemodel/attention/triton/mla.py | 3 +- lightllm/common/basemodel/basemodel.py | 1 + .../flash_decoding/gqa_flash_decoding_vsm.py | 11 +- .../triton_kernel/fused_moe/grouped_topk.py | 2 +- .../triton_kernel/fused_moe/topk_select.py | 8 +- .../mla_att/decode_att/gqa_flash_decoding.py | 13 +- .../context_flashattention_nopad_with_v.py | 46 +- lightllm/common/quantization/no_quant.py | 4 +- ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 +++ ...num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 +++ .../{topk_num=4}_NVIDIA_H200.json | 50 ++ ...orch.bfloat16,topk_num=4}_NVIDIA_H200.json | 74 ++ ...=20,dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 104 +++ lightllm/models/__init__.py | 1 + lightllm/models/deepseek2/model.py | 2 + lightllm/models/glm4_moe_lite/__init__.py | 4 + lightllm/models/glm4_moe_lite/infer_struct.py | 12 + .../glm4_moe_lite/layer_infer/__init__.py | 3 + .../layer_infer/transformer_layer_infer.py | 96 +++ .../glm4_moe_lite/layer_weights/__init__.py | 3 + .../layer_weights/transformer_layer_weight.py | 117 +++ lightllm/models/glm4_moe_lite/model.py | 74 ++ lightllm/models/glm4_moe_lite_mtp/__init__.py | 3 + .../glm4_moe_lite_mtp/layer_infer/__init__.py | 3 + .../layer_infer/pre_layer_infer.py | 82 +++ .../layer_weights/__init__.py | 5 + .../pre_and_post_layer_weight.py | 48 ++ lightllm/models/glm4_moe_lite_mtp/model.py | 90 +++ lightllm/server/api_cli.py | 6 +- lightllm/server/function_call_parser.py | 273 ++++++- .../model_infer/mode_backend/base_backend.py | 4 + test/eval/eval_bfcl.py | 686 ++++++++++++++++++ test/eval/requirements.txt | 13 + test/eval/run_bfcl.sh | 48 ++ 37 files changed, 2230 insertions(+), 34 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/models/glm4_moe_lite/__init__.py create mode 100644 lightllm/models/glm4_moe_lite/infer_struct.py create mode 100644 lightllm/models/glm4_moe_lite/layer_infer/__init__.py create mode 100644 lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/glm4_moe_lite/layer_weights/__init__.py create mode 100644 lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/glm4_moe_lite/model.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/__init__.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/glm4_moe_lite_mtp/model.py create mode 100755 test/eval/eval_bfcl.py create mode 100644 test/eval/requirements.txt create mode 100755 test/eval/run_bfcl.sh diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 6c74b22e1..537dbee22 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -16,6 +16,8 @@ def __init__(self, model): self.qk_nope_head_dim = model.qk_nope_head_dim self.qk_rope_head_dim = model.qk_rope_head_dim self.kv_lora_rank = model.kv_lora_rank + # v_head_dim may differ from qk_nope_head_dim (e.g., GLM-4.7-Flash: v_head_dim=256, qk_nope_head_dim=192) + self.v_head_dim = getattr(model, "v_head_dim", self.qk_nope_head_dim) self.q_data_type = model.data_type self.kv_data_type = model.data_type self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) @@ -69,7 +71,7 @@ def init_state(self): num_qo_heads=self.backend.tp_q_head_num, num_kv_heads=self.backend.tp_q_head_num, head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, - head_dim_vo=self.backend.qk_nope_head_dim, + head_dim_vo=self.backend.v_head_dim, # Use v_head_dim, not qk_nope_head_dim q_data_type=self.backend.q_data_type, causal=True, sm_scale=self.backend.softmax_scale, @@ -101,7 +103,8 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda") + # Output dimension is v_head_dim (from v.shape[-1]), not qk_nope_head_dim + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) self.prefill_wrapper.run(q, k, v, out=o_tensor) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index 8288193ad..fbdae4012 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -44,7 +44,8 @@ def _mla_prefill_att( qk_rope_head_dim = 64 q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] - o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + # GLM-4.7-Flash : v_head_dim != qk_nope_head_dim + 
o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device) k_nope, k_rope = k assert att_control.mla_prefill softmax_scale = att_control.mla_prefill_dict["softmax_scale"] diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e6405e4d7..5c1d2b871 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1022,6 +1022,7 @@ def _gen_special_model_input(self, token_num: int): "Deepseek3MTPModel" in str(self.__class__) or "Qwen3MOEMTPModel" in str(self.__class__) or "MistralMTPModel" in str(self.__class__) + or "Glm4MoeLiteMTPModel" in str(self.__class__) ) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 6a9bb79c7..141587ff3 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -81,8 +81,11 @@ def _fwd_kernel_calcu_index_and_block_seq( vsm_count, batch_size, BLOCK_N: tl.constexpr, + MAX_BATCH_SIZE: tl.constexpr, ): - b_seq_len = tl.load(b_seq_len + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + b_seq_len = tl.load( + b_seq_len + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0 + ) total_token_num = tl.sum(b_seq_len) block_seq = tl.cdiv(total_token_num, vsm_count * 4) @@ -93,9 +96,9 @@ def _fwd_kernel_calcu_index_and_block_seq( cumsum_seq_len = tl.cumsum(block_seq_len) batch_start_index = cumsum_seq_len - block_seq_len tl.store( - mid_o_batch_start_index + tl.arange(0, 2048), + mid_o_batch_start_index + tl.arange(0, MAX_BATCH_SIZE), batch_start_index, - mask=tl.arange(0, 2048) < batch_size, + mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, ) tl.store(mid_o_decode_att_block_seq, block_seq) @@ -455,7 +458,6 @@ def gqa_token_decode_attention_flash_decoding_vsm( ) if not hasattr(infer_state, "decode_att_block_seq"): - assert batch_size <= 2048 decode_att_block_seq = torch.empty( [ 1, @@ -477,6 +479,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( num_vsm, batch_size, BLOCK_N=run_config["BLOCK_N"], + MAX_BATCH_SIZE=triton.next_power_of_2(batch_size), num_warps=4, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py index fb0323cd4..2687adf14 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py @@ -227,7 +227,7 @@ def triton_grouped_topk( scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 72c3a381e..7ac5a03b5 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ 
b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -196,10 +196,12 @@ def select_experts( scoring_func=scoring_func, ) else: - group_score_topk_num = 1 - # for deepseek v3 - if topk_group == 4 and num_expert_group == 8 and top_k == 8: + if correction_bias is not None: group_score_topk_num = 2 + elif topk_group == 4 and num_expert_group == 8 and top_k == 8: + group_score_topk_num = 2 + else: + group_score_topk_num = 1 topk_weights, topk_ids = triton_grouped_topk( hidden_states=hidden_states, diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 28839b5f5..063181d99 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -67,7 +67,6 @@ def gqa_token_decode_attention_flash_decoding( ) if not hasattr(infer_state, "decode_att_block_seq"): - assert batch_size <= 2048 decode_att_block_seq = torch.empty( [ 1, @@ -89,6 +88,7 @@ def gqa_token_decode_attention_flash_decoding( vsm_count, batch_size, BLOCK_N=BLOCK_N, + MAX_BATCH_SIZE=triton.next_power_of_2(batch_size), num_warps=4, ) @@ -134,8 +134,11 @@ def _fwd_kernel_calcu_index_and_block_seq( num_sm, batch_size, BLOCK_N: tl.constexpr, + MAX_BATCH_SIZE: tl.constexpr, ): - b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + b_seq_len = tl.load( + b_seq_len_ptr + tl.arange(0, MAX_BATCH_SIZE), mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, other=0 + ) total_token_num = tl.sum(b_seq_len) block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1 @@ -144,6 +147,10 @@ def _fwd_kernel_calcu_index_and_block_seq( block_seq_len = tl.cdiv(b_seq_len, block_seq) cumsum_seq_len = tl.cumsum(block_seq_len) batch_start_index = cumsum_seq_len - block_seq_len - tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size) + tl.store( + mid_o_batch_start_index_ptr + tl.arange(0, MAX_BATCH_SIZE), + batch_start_index, + mask=tl.arange(0, MAX_BATCH_SIZE) < batch_size, + ) tl.store(mid_o_decode_att_block_seq_ptr, block_seq) return diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py index be0635182..d79020844 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py @@ -36,6 +36,9 @@ def _fwd_kernel_with_v( BLOCK_DMODEL: tl.constexpr, BLOCK_ROPE_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_V_DMODEL: tl.constexpr, + ACTUAL_DMODEL: tl.constexpr, + ACTUAL_V_DMODEL: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -53,8 +56,13 @@ def _fwd_kernel_with_v( # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_v_d = tl.arange(0, BLOCK_V_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + d_mask = offs_d < ACTUAL_DMODEL + v_d_mask = offs_v_d < ACTUAL_V_DMODEL + off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] off_q_rope = ( (cur_batch_in_q_start_index + offs_m[:, None]) * 
stride_q_rope_bs @@ -63,9 +71,10 @@ def _fwd_kernel_with_v( ) off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] - off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_v_d[None, :] - q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q_mask = (offs_m[:, None] < cur_batch_seq_len) & d_mask[None, :] + q = tl.load(Q_nope + off_q, mask=q_mask, other=0.0) q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K_nope + off_k @@ -75,7 +84,7 @@ def _fwd_kernel_with_v( # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_V_DMODEL], dtype=tl.float32) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) @@ -83,14 +92,16 @@ def _fwd_kernel_with_v( for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- + k_seq_mask = (start_n + offs_n[None, :]) < block_end_loc + k_mask = k_seq_mask & d_mask[:, None] k = tl.load( k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, + mask=k_mask, other=0.0, ) k_rope = tl.load( k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, - mask=(start_n + offs_n[None, :]) < block_end_loc, + mask=k_seq_mask, other=0.0, ) @@ -112,9 +123,11 @@ def _fwd_kernel_with_v( # -- update output accumulator -- acc = acc * alpha[:, None] # update acc + v_seq_mask = (start_n + offs_n[:, None]) < block_end_loc + v_mask = v_seq_mask & v_d_mask[None, :] v = tl.load( v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < block_end_loc, + mask=v_mask, other=0.0, ) p = p.to(v.dtype) @@ -124,9 +137,10 @@ def _fwd_kernel_with_v( acc = acc / l_i[:, None] # initialize pointers to output - off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] + off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_v_d[None, :] out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + o_mask = (offs_m[:, None] < cur_batch_seq_len) & v_d_mask[None, :] + tl.store(out_ptrs, acc, mask=o_mask) return @@ -149,13 +163,14 @@ def context_attention_fwd_with_v( BLOCK = 128 if not is_tesla() else 64 q_nope_dim = q_nope.shape[-1] q_rope_dim = q_rope.shape[-1] + v_dim = v.shape[-1] assert q_nope_dim == k_nope.shape[-1] assert q_rope_dim == k_rope.shape[-1] - assert q_nope_dim in {16, 32, 64, 128, 256, 512} - assert q_rope_dim in {16, 32, 64, 128, 256} - assert q_nope_dim == v.shape[-1] - if q_nope_dim >= 512: + q_nope_dim_padded = triton.next_power_of_2(q_nope_dim) + v_dim_padded = triton.next_power_of_2(v_dim) + + if q_nope_dim_padded >= 512 or v_dim_padded >= 512: BLOCK = 64 if not is_tesla() else 32 else: BLOCK = 128 if not is_tesla() else 64 @@ -167,7 +182,7 @@ def context_attention_fwd_with_v( batch, head = b_seq_len.shape[0], q_nope.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - num_warps = 4 if q_nope_dim 
<= 64 else 8 + num_warps = 4 if q_nope_dim_padded <= 64 else 8 _fwd_kernel_with_v[grid]( q_nope, @@ -194,9 +209,12 @@ def context_attention_fwd_with_v( o.stride(1), b_prompt_cache_len=b_prompt_cache_len, BLOCK_M=BLOCK, - BLOCK_DMODEL=q_nope_dim, + BLOCK_DMODEL=q_nope_dim_padded, BLOCK_ROPE_DMODEL=q_rope_dim, BLOCK_N=BLOCK, + BLOCK_V_DMODEL=v_dim_padded, + ACTUAL_DMODEL=q_nope_dim, + ACTUAL_V_DMODEL=v_dim, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index 3bf023f8a..fa926ad6f 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -28,8 +28,8 @@ def apply( device = input_tensor.device if use_custom_tensor_mananger: out = g_cache_manager.alloc_tensor(shape, dtype, device=device) - else: - out = torch.empty(shape, dtype=dtype, device=device) + else: + out = torch.empty(shape, dtype=dtype, device=device) if bias is None: return torch.mm(input_tensor, weight, out=out) return torch.addmm(bias, input_tensor, weight, out=out) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..deb97363c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 
128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..a6c93c3f6 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json new file mode 100644 index 000000000..0f0c175b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + 
"BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "16384": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json new file mode 100644 index 000000000..c6c3d54ff --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + }, + "8": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..5601eab76 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_SEQ": 32, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 8 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 1 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "BLOCK_SEQ": 8, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + 
"2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..b82f25e17 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "100": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..ab4644621 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,104 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "4": { 
+ "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "65536": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 539b32dec..095f73679 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, InternVLPhi3TpPartModel, diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0739a8a8..c7f13bb63 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -43,6 +43,8 @@ def _init_some_value(self): self.qk_rope_head_dim = self.config["qk_rope_head_dim"] self.q_lora_rank = self.config["q_lora_rank"] self.kv_lora_rank = self.config["kv_lora_rank"] + # v_head_dim defaults to qk_nope_head_dim for DeepSeek-V2, but GLM-4.7-Flash has different value + self.v_head_dim = self.config.get("v_head_dim", self.qk_nope_head_dim) self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim def _init_custom(self): diff --git a/lightllm/models/glm4_moe_lite/__init__.py b/lightllm/models/glm4_moe_lite/__init__.py new file mode 100644 index 000000000..b00657090 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/__init__.py @@ -0,0 +1,4 @@ +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo + +__all__ = ["Glm4MoeLiteTpPartModel", "Glm4MoeLiteInferStateInfo"] diff --git a/lightllm/models/glm4_moe_lite/infer_struct.py b/lightllm/models/glm4_moe_lite/infer_struct.py new file mode 100644 index 000000000..92c350abc --- /dev/null +++ b/lightllm/models/glm4_moe_lite/infer_struct.py @@ -0,0 +1,12 @@ +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo + + +class Glm4MoeLiteInferStateInfo(Deepseek2InferStateInfo): + """Inference state for GLM-4.7-Flash (glm4_moe_lite architecture). + + Inherits from Deepseek2InferStateInfo as GLM-4.7-Flash uses the same + MLA (Multi-Head Latent Attention) mechanism as DeepSeek-V2/V3. 
+ """ + + def __init__(self): + super().__init__() diff --git a/lightllm/models/glm4_moe_lite/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py new file mode 100644 index 000000000..a95580535 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_infer/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer + +__all__ = ["Glm4MoeLiteTransformerLayerInfer"] diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..bcea872fd --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -0,0 +1,96 @@ +import os +import torch +import torch.distributed as dist +import triton +from functools import partial +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.distributed.communication_op import reduce_scatter_tensor + + +class Glm4MoeLiteTransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config): + self._glm4_layer_num = layer_num + self._glm4_first_k_dense = network_config.get("first_k_dense_replace", 0) + self._glm4_has_routed_experts = network_config.get("n_routed_experts") is not None + super().__init__(layer_num, network_config) + + @property + def is_moe(self): + return self._glm4_has_routed_experts and self._glm4_layer_num >= self._glm4_first_k_dense + + @is_moe.setter + def is_moe(self, value): + pass + + def _bind_ffn(self): + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) + self._tpsp_ffn = self._tpsp_ffn_ep + else: + self._ffn = partial(Glm4MoeLiteTransformerLayerInfer._moe_ffn, self) + self._tpsp_ffn = self._tpsp_ffn_tp + else: + self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) + self._tpsp_ffn = self._tpsp_ffn_tp + + def _get_o(self, input: torch.Tensor, infer_state, layer_weight) -> torch.Tensor: + if input.shape[2] == self.kv_lora_rank: + input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) + o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) + return o_tensor + + def _tpsp_get_o(self, input, infer_state, layer_weight) -> torch.Tensor: + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) + + if input.shape[2] == self.kv_lora_rank: + input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) + + input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim) + dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ + o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) + layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) + e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] + if e_o_tensor.shape[0] > 0: + e_o_tensor.fill_(0) + + if self.tp_world_size_ > 1: + sp_token_num = o_tensor.shape[0] // self.tp_world_size_ + reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device) + reduce_scatter_tensor( + output=reduce_o_tensor, + input=o_tensor, + op=dist.ReduceOp.SUM, + group=infer_state.dist_group, + 
async_op=False, + ) + o_tensor = reduce_o_tensor + + return o_tensor + + def _moe_ffn(self, input, infer_state, layer_weight): + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: + shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) + + router_logits = layer_weight.moe_gate.mm(hidden_states.to(torch.float32)) + + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=self.n_group, + topk_group=self.topk_group, + num_expert_group=self.n_group, + ) + + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: + hidden_states.add_(shared_output) + + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/glm4_moe_lite/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py new file mode 100644 index 000000000..1fd5e36f8 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_weights/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight + +__all__ = ["Glm4MoeLiteTransformerLayerWeight"] diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..2a10f6fdc --- /dev/null +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -0,0 +1,117 @@ +import os +import torch +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight + + +class Glm4MoeLiteTransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + + def _parse_config(self): + # Call parent's _parse_config to set n_embed, moe_inter, and other required attributes + super()._parse_config() + + # Override is_moe calculation for GLM4 (no moe_layer_freq check) + self.is_moe = self.network_config_.get( + "n_routed_experts" + ) is not None and self.layer_num_ >= self.network_config_.get("first_k_dense_replace", 0) + + # Override num_fused_shared_experts with GLM4-specific logic + from lightllm.utils.envs_utils import get_env_start_args + + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + moe_mode = os.getenv("MOE_MODE", "TP") + assert moe_mode == "TP" + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) + + def load_hf_weights(self, weights): + from lightllm.common.basemodel import TransformerLayerWeight + from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant + + kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") + weight_scale_suffix = None + if self.quant_cfg.quantized_weight: + weight_scale_suffix = kv_b_quant_method.weight_scale_suffix + + if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: + kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] + # for quantized weights, dequantize first + if self.quant_cfg.quantized_weight: + kv_b_proj_ = weight_dequant( + 
kv_b_proj_.cuda(), + weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), + ).cpu() + # Use GLM4-specific split methods (different from DeepSeek2's dimensions) + k_b_proj_ = self._load_kb(kv_b_proj_) + v_b_proj_ = self._load_vb(kv_b_proj_) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = v_b_proj_ + + # rename the shared experts weight + if self.num_fused_shared_experts > 0: + self._rename_shared_experts(weights, weight_scale_suffix) + + return TransformerLayerWeight.load_hf_weights(self, weights) + + def _load_kb(self, kv_b_proj_): + kv_dim = self.qk_nope_head_dim + self.v_head_dim + k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, kv_dim, self.kv_lora_rank)[:, : self.qk_nope_head_dim, :] + return k_b_proj_.contiguous().to(kv_b_proj_.dtype) + + def _load_kb_scale(self, kv_b_proj_, block_size): + kv_dim = self.qk_nope_head_dim + self.v_head_dim + k_b_proj_scale_ = kv_b_proj_.view( + self.num_attention_heads, kv_dim // block_size, self.kv_lora_rank // block_size + )[:, : self.qk_nope_head_dim // block_size, :] + return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype) + + def _load_vb(self, kv_b_proj_): + kv_dim = self.qk_nope_head_dim + self.v_head_dim + v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, kv_dim)[ + :, :, self.qk_nope_head_dim : + ].transpose(0, 1) + return v_b_proj_.contiguous().to(kv_b_proj_.dtype) + + def _load_vb_scale(self, kv_b_proj_scale_, block_size): + kv_dim = self.qk_nope_head_dim + self.v_head_dim + v_b_proj_scale_ = kv_b_proj_scale_.T.view( + self.kv_lora_rank // block_size, + self.num_attention_heads, + kv_dim // block_size, + )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) + return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + + def _init_moe(self): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + hidden_size = self.network_config_["hidden_size"] + + self.moe_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[self.n_routed_experts], + weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", + data_type=torch.float32, # Router gate needs float32 for numerical stability + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + if self.num_fused_shared_experts == 0: + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) + + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py new file mode 100644 index 000000000..a5970ab59 --- /dev/null +++ b/lightllm/models/glm4_moe_lite/model.py @@ -0,0 +1,74 @@ +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.glm4_moe_lite.layer_infer.transformer_layer_infer import Glm4MoeLiteTransformerLayerInfer +from 
lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@ModelRegistry("glm4_moe_lite") +class Glm4MoeLiteTpPartModel(Deepseek2TpPartModel): + + transformer_weight_class = Glm4MoeLiteTransformerLayerWeight + transformer_layer_infer_class = Glm4MoeLiteTransformerLayerInfer + infer_state_class = Glm4MoeLiteInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + + def _init_config(self): + super()._init_config() + + if "moe_layer_freq" not in self.config and self.config.get("n_routed_experts"): + self.config["moe_layer_freq"] = 1 + + if "routed_scaling_factor" not in self.config: + self.config["routed_scaling_factor"] = 1.8 + + if "topk_method" not in self.config: + self.config["topk_method"] = "noaux_tc" + + if "scoring_func" not in self.config: + self.config["scoring_func"] = "sigmoid" + + logger.info( + f"GLM-4.7-Flash config: " + f"n_routed_experts={self.config.get('n_routed_experts')}, " + f"n_shared_experts={self.config.get('n_shared_experts')}, " + f"num_experts_per_tok={self.config.get('num_experts_per_tok')}, " + f"first_k_dense_replace={self.config.get('first_k_dense_replace')}, " + f"routed_scaling_factor={self.config.get('routed_scaling_factor')}, " + f"scoring_func={self.config.get('scoring_func')}" + ) + + def _init_custom(self): + self._init_to_get_yarn_rotary() + dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + + def _init_to_get_yarn_rotary(self): + rope_scaling = self.config.get("rope_scaling") + + if rope_scaling is None: + self._init_glm4_standard_rotary() + else: + super()._init_to_get_yarn_rotary() + + def _init_glm4_standard_rotary(self): + rope_theta = self.config.get("rope_theta", 1000000.0) + qk_rope_head_dim = self.config.get("qk_rope_head_dim", 64) + max_position_embeddings = self.config.get("max_position_embeddings", 202752) + + dim = qk_rope_head_dim + + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.float32) / dim)) + + max_seq_len = max(max_position_embeddings, self.max_seq_length) + t = torch.arange(max_seq_len, device="cpu", dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() diff --git a/lightllm/models/glm4_moe_lite_mtp/__init__.py b/lightllm/models/glm4_moe_lite_mtp/__init__.py new file mode 100644 index 000000000..96b6659c8 --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel + +__all__ = ["Glm4MoeLiteMTPModel"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e357bfa19 --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer + +__all__ = ["Glm4MoeLiteMTPPreLayerInfer"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..6994d21f7 --- /dev/null 
+++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,82 @@ +import torch + +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) +from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer + + +class Glm4MoeLiteMTPPreLayerInfer(LlamaPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + + def _mtp_context_forward( + self, + input_embdings, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == tgt_embdings.shape[0] + ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" + + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def _mtp_token_forward( + self, + input_embdings, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + tgt_embdings = infer_state.mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + + layer_weight.enorm_weight_.rmsnorm_forward( + input=input_embdings, + eps=self.eps_, + out=input_embdings, + ) + layer_weight.hnorm_weight_.rmsnorm_forward( + input=tgt_embdings, + eps=self.eps_, + out=tgt_embdings, + ) + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) + return ans_logics + + def context_forward( + self, + input_ids, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, + input_ids, + infer_state: Glm4MoeLiteInferStateInfo, + layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..57fe578cf --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/__init__.py @@ -0,0 +1,5 @@ +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) + +__all__ = ["Glm4MoeLiteMTPPreAndPostLayerWeight"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..78b4d411c --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,48 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + 
RMSNormWeight, + ROWMMWeight, +) +from lightllm.common.quantization import Quantcfg + + +class Glm4MoeLiteMTPPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + super().__init__(data_type, network_config) + self.quant_cfg: Quantcfg = quant_cfg + + mtp_layer_idx = network_config["num_hidden_layers"] + hidden_size = network_config["hidden_size"] + + self.eh_proj_weight_ = ROWMMWeight( + in_dim=hidden_size * 2, + out_dims=[hidden_size], + weight_names=f"model.layers.{mtp_layer_idx}.eh_proj.weight", + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(mtp_layer_idx, "eh_proj"), + tp_rank=0, + tp_world_size=1, + ) + + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.enorm.weight", + data_type=self.data_type_, + ) + + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.hnorm.weight", + data_type=self.data_type_, + ) + + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{mtp_layer_idx}.shared_head.norm.weight", + data_type=self.data_type_, + ) + + self.wte_weight_: EmbeddingWeight = None + self.lm_head_weight_: LMHeadWeight = None diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py new file mode 100644 index 000000000..dea415abb --- /dev/null +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -0,0 +1,90 @@ +from typing import List +from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel +from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer +from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( + Glm4MoeLiteMTPPreAndPostLayerWeight, +) +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.common.basemodel.basemodel import load_hf_weights + + +class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): + + pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight + pre_layer_infer_class = Glm4MoeLiteMTPPreLayerInfer + + def __init__(self, kvargs: dict): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs: dict): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_custom(self): + self._cos_cached = self.main_model._cos_cached + self._sin_cached = self.main_model._sin_cached + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + self.mem_manager = self.main_model.mem_manager + + def _init_weights(self, start_layer_index=None): + assert start_layer_index is None + + mtp_layer_start = self.config["num_hidden_layers"] + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + + self.pre_post_weight = self.pre_and_post_weight_class( + self.data_type, network_config=self.config, quant_cfg=self.quant_cfg + ) + + self.trans_layers_weight = [ + self.transformer_weight_class( + i, + self.data_type, + network_config=self.config, + quant_cfg=self.quant_cfg, + ) + for i in range(mtp_layer_start, mtp_layer_start + num_mtp_layers) + ] + + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + + self.pre_post_weight.verify_load() + [weight.verify_load() for weight in 
self.trans_layers_weight] + + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_ + + def _init_infer_layer(self, start_layer_index=None): + assert start_layer_index is None + + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + + total_pre_layers_num = len(self.main_model.layers_infer) + total_pre_layers_num += sum( + [len(previous_model.layers_infer) for previous_model in self.mtp_previous_draft_models] + ) + + num_mtp_layers = self.config.get("num_nextn_predict_layers", 1) + self.layers_infer = [ + self.transformer_layer_infer_class(i, network_config=self.config) + for i in range(total_pre_layers_num, total_pre_layers_num + num_mtp_layers) + ] + + def _init_some_value(self): + super()._init_some_value() + self.layers_num = self.config.get("num_nextn_predict_layers", 1) + + def autotune_layers(self): + return self.config.get("num_nextn_predict_layers", 1) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index e49b0cc67..1652318f6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -119,7 +119,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--batch_max_tokens", type=int, - default=None, + default=16384, help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", ) parser.add_argument( @@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) @@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") - parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 1620cff13..9214715b1 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -241,7 +241,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami if start_idx >= len(current_text): return StreamingParseResult() - (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags) + obj, end_idx = _partial_json_loads(current_text[start_idx:], flags) is_current_complete = _is_complete_json(current_text[start_idx : start_idx + end_idx]) @@ -1173,6 +1173,276 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text=current_text) +class Glm47Detector(BaseFormatDetector): + """ + Detector for GLM-4.7/GLM-4.7-Flash model function call format. 
+
+    The GLM-4.7 format uses an XML-style envelope with arg_key/arg_value pairs
+    instead of JSON arguments.
+
+    Format Structure:
+    ```
+    <tool_call>function_name
+    <arg_key>param1</arg_key>
+    <arg_value>value1</arg_value>
+    <arg_key>param2</arg_key>
+    <arg_value>value2</arg_value>
+    </tool_call>
+    ```
+
+    Example:
+    ```
+    <tool_call>tool_brave_web_search_post
+    <arg_key>query</arg_key>
+    <arg_value>test search</arg_value>
+    <arg_key>count</arg_key>
+    <arg_value>5</arg_value>
+    </tool_call>
+    ```
+
+    Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+    - Function Name: Appears on the first line after `<tool_call>`
+    - Arguments: Pairs of `<arg_key>name</arg_key>` and `<arg_value>value</arg_value>`
+
+    Reference: https://github.com/vllm-project/vllm/blob/main/vllm/tool_parsers/glm4_moe_tool_parser.py
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<tool_call>"
+        self.eot_token = "</tool_call>"
+        self.tool_call_separator = "\n"
+
+        # Regex patterns for parsing GLM-4.7 tool calls
+        # Match complete tool call blocks
+        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
+        # Extract function name and arguments from a tool call block
+        # Function name can be followed by newline OR directly by <arg_key>
+        # Pattern: <tool_call>function_name(\n|<arg_key>)...</tool_call>
+        self.func_detail_regex = re.compile(
+            r"<tool_call>([^<\n]+?)(?:\n|(?=<arg_key>)|(?=</tool_call>))(.*?)</tool_call>", re.DOTALL
+        )
+        # Extract arg_key/arg_value pairs
+        self.func_arg_regex = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL)
+
+        self._last_arguments = ""
+        self._normal_text_buffer = ""
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a GLM-4.7 format tool call."""
+        return self.bot_token in text
+
+    def _parse_xml_arguments(self, arg_text: str) -> dict:
+        """
+        Parse XML-style arguments into a dictionary.
+
+        Args:
+            arg_text: The text containing <arg_key>/<arg_value> pairs
+
+        Returns:
+            Dictionary of argument name to value
+        """
+        if not arg_text:
+            return {}
+
+        args = {}
+        matches = self.func_arg_regex.findall(arg_text)
+        for key, value in matches:
+            key = key.strip()
+            value = value.strip()
+            # Try to parse value as JSON for complex types (arrays, objects, numbers, booleans)
+            try:
+                parsed_value = json.loads(value)
+                args[key] = parsed_value
+            except (json.JSONDecodeError, ValueError):
+                # Keep as string if not valid JSON
+                args[key] = value
+        return args
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: StreamingParseResult with normal_text and parsed calls.
+        """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+
+        tool_indices = self._get_tool_indices(tools)
+        calls = []
+
+        # Find all <tool_call>...</tool_call> blocks
+        match_result_list = self.func_call_regex.findall(text)
+
+        for match_result in match_result_list:
+            try:
+                # Extract function name and arguments
+                func_detail = self.func_detail_regex.search(match_result)
+                if not func_detail:
+                    logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}")
+                    continue
+
+                func_name = func_detail.group(1).strip()
+                arg_text = func_detail.group(2) if func_detail.group(2) else ""
+
+                # Validate function name
+                if func_name not in tool_indices:
+                    logger.warning(f"Model attempted to call undefined function: {func_name}")
+                    continue
+
+                # Parse XML arguments to JSON
+                func_args = self._parse_xml_arguments(arg_text)
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=tool_indices[func_name],
+                        name=func_name,
+                        parameters=json.dumps(func_args, ensure_ascii=False),
+                    )
+                )
+            except Exception as e:
+                logger.warning(f"Failed to parse GLM-4.7 tool call: {match_result}, error: {str(e)}")
+                continue
+
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        Streaming incremental parsing for GLM-4.7 tool calls.
+
+        This handles the streaming case where tool calls arrive incrementally.
+        """
+        self._buffer += new_text
+        current_text = self._buffer
+
+        # Check if we have a tool call starting
+        if not self.has_tool_call(current_text):
+            # Check for partial bot_token at the end
+            partial_len = self._ends_with_partial_token(current_text, self.bot_token)
+            if partial_len:
+                # Might be partial bot_token, keep buffering
+                return StreamingParseResult()
+
+            # No tool call, emit as normal text
+            self._buffer = ""
+            # Clean up any stray end tokens
+            if self.eot_token in new_text:
+                new_text = new_text.replace(self.eot_token, "")
+            return StreamingParseResult(normal_text=new_text)
+
+        # Build tool indices if not already built
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = self._get_tool_indices(tools)
+
+        calls: List[ToolCallItem] = []
+
+        try:
+            # Check if we have a complete tool call
+            if self.eot_token in current_text:
+                # We have at least one complete tool call
+                # Parse all complete tool calls
+                result = self.detect_and_parse(current_text, tools)
+
+                # Find the end of the last complete tool call
+                last_end = current_text.rfind(self.eot_token)
+                if last_end != -1:
+                    remaining = current_text[last_end + len(self.eot_token) :]
+                    self._buffer = remaining.lstrip()
+                else:
+                    self._buffer = ""
+
+                # Reset state for next tool call
+                self.current_tool_id = -1
+                self.current_tool_name_sent = False
+                self._last_arguments = ""
+
+                return result
+
+            # We have a partial tool call - try to stream it
+            # Extract what we can from the partial tool call
+            tool_call_start = current_text.find(self.bot_token)
+            if tool_call_start == -1:
+                return StreamingParseResult()
+
+            # Get content after <tool_call>
+            content_after_start = current_text[tool_call_start + len(self.bot_token) :]
+
+            # Try to extract function name (first line after <tool_call>)
+            newline_pos = content_after_start.find("\n")
+            if newline_pos == -1:
+                # Still waiting for function name to complete
+                return StreamingParseResult()
+
+            func_name = content_after_start[:newline_pos].strip()
+
+            # Initialize state if this is the first tool call
+            if self.current_tool_id == -1:
+                self.current_tool_id = 0
+                self.prev_tool_call_arr = []
+                self.streamed_args_for_tool = [""]
+
+            # Ensure we have enough entries
+            while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+            while
len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Check if function name is valid + if func_name and func_name in self._tool_indices: + if not self.current_tool_name_sent: + # Send function name first + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments incrementally + arg_text = content_after_start[newline_pos + 1 :] + current_args = self._parse_xml_arguments(arg_text) + + if current_args: + current_args_json = json.dumps(current_args, ensure_ascii=False) + prev_args = self.prev_tool_call_arr[self.current_tool_id].get("arguments", {}) + prev_args_json = json.dumps(prev_args, ensure_ascii=False) if prev_args else "" + + if current_args_json != prev_args_json: + # Calculate the diff + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in GLM-4.7 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1185,6 +1455,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, "mistral": MistralDetector, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 805c9b8e5..64310d6b0 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -37,6 +37,7 @@ from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel +from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -328,6 +329,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif mtp_model_cfg["model_type"] == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) + elif mtp_model_cfg["model_type"] == "glm4_moe_lite": + assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] + self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) else: assert False, f"error mtp mode {mtp_model_cfg['model_type']}" diff --git a/test/eval/eval_bfcl.py b/test/eval/eval_bfcl.py new file mode 100755 index 000000000..59b81c26d --- /dev/null +++ b/test/eval/eval_bfcl.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python3 +""" +BFCL (Berkeley Function Calling Leaderboard) Evaluation Script for LightLLM + 
+This script evaluates function/tool calling capabilities on the BFCL benchmark. + +Usage: + # Start LightLLM server first: + python -m lightllm.server.api_server --model_dir /path/to/GLM-4.7-Flash --tp 1 + + # Run evaluation: + python eval_bfcl.py \ + --model_name GLM-4.7-Flash \ + --base_url http://localhost:8000/v1 \ + --test_category simple + +Test Categories: + - simple: Single function calls (400 examples) + - multiple: Select one function from multiple options (200 examples) + - parallel: Multiple function calls in parallel (200 examples) + - parallel_multiple: Combination of parallel and multiple (200 examples) + - java: Java function calls (100 examples) + - javascript: JavaScript function calls (70 examples) + - irrelevance: Detect when no function should be called + - all: Run all categories + +Requirements: + pip install openai tqdm huggingface_hub +""" + +import argparse +import json +import os +import re +import ast +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +from collections import defaultdict + +from tqdm import tqdm + +try: + from openai import OpenAI +except ImportError: + print("Please install openai: pip install openai") + exit(1) + +try: + from huggingface_hub import hf_hub_download +except ImportError: + print("Please install huggingface_hub: pip install huggingface_hub") + exit(1) + + +# BFCL Dataset on HuggingFace +BFCL_REPO = "gorilla-llm/Berkeley-Function-Calling-Leaderboard" + +# Test category mappings to filenames +TEST_CATEGORIES = { + "simple": "BFCL_v3_simple.json", + "multiple": "BFCL_v3_multiple.json", + "parallel": "BFCL_v3_parallel.json", + "parallel_multiple": "BFCL_v3_parallel_multiple.json", + "java": "BFCL_v3_java.json", + "javascript": "BFCL_v3_javascript.json", + "irrelevance": "BFCL_v3_irrelevance.json", + "live_simple": "BFCL_v3_live_simple.json", + "live_multiple": "BFCL_v3_live_multiple.json", + "live_parallel": "BFCL_v3_live_parallel.json", + "live_parallel_multiple": "BFCL_v3_live_parallel_multiple.json", + "rest": "BFCL_v3_rest.json", + "sql": "BFCL_v3_sql.json", +} + +# Possible answer files for ground truth +ANSWER_FILES = { + "simple": "possible_answer/BFCL_v3_simple.json", + "multiple": "possible_answer/BFCL_v3_multiple.json", + "parallel": "possible_answer/BFCL_v3_parallel.json", + "parallel_multiple": "possible_answer/BFCL_v3_parallel_multiple.json", + "java": "possible_answer/BFCL_v3_java.json", + "javascript": "possible_answer/BFCL_v3_javascript.json", + "live_simple": "possible_answer/BFCL_v3_live_simple.json", + "live_multiple": "possible_answer/BFCL_v3_live_multiple.json", + "live_parallel": "possible_answer/BFCL_v3_live_parallel.json", + "live_parallel_multiple": "possible_answer/BFCL_v3_live_parallel_multiple.json", + "sql": "possible_answer/BFCL_v3_sql.json", +} + + +@dataclass +class EvalResult: + """Result of a single evaluation.""" + + task_id: str + category: str + passed: bool + model_output: str + expected: Any + error: Optional[str] = None + + +def download_bfcl_file(filename: str) -> str: + """Download a BFCL file from HuggingFace Hub.""" + try: + local_path = hf_hub_download( + repo_id=BFCL_REPO, + filename=filename, + repo_type="dataset", + ) + return local_path + except Exception as e: + print(f"Error downloading {filename}: {e}") + return None + + +def load_jsonl_or_json(filepath: str) -> List[Dict[str, Any]]: + """Load data from JSON or JSONL file.""" + data = [] + with open(filepath, "r", 
encoding="utf-8") as f: + content = f.read().strip() + # Try as JSON array first + try: + data = json.loads(content) + if isinstance(data, dict): + data = [data] + except json.JSONDecodeError: + # Try as JSONL + f.seek(0) + for line in f: + line = line.strip() + if line: + try: + data.append(json.loads(line)) + except json.JSONDecodeError: + continue + return data + + +def load_bfcl_data(category: str, limit: Optional[int] = None) -> List[Dict[str, Any]]: + """Load BFCL dataset for a specific category.""" + filename = TEST_CATEGORIES.get(category) + if not filename: + print(f"Unknown category: {category}") + return [] + + print(f"Downloading {filename} from HuggingFace...") + filepath = download_bfcl_file(filename) + if not filepath: + return [] + + print(f"Loading data from {filepath}") + data = load_jsonl_or_json(filepath) + + # Also load ground truth answers if available + answer_file = ANSWER_FILES.get(category) + if answer_file: + print(f"Downloading answer file {answer_file}...") + answer_path = download_bfcl_file(answer_file) + if answer_path: + answers = load_jsonl_or_json(answer_path) + # Create a mapping from id to answer + answer_map = {} + for ans in answers: + ans_id = ans.get("id", "") + answer_map[ans_id] = ans.get("ground_truth", ans.get("result", [])) + + # Merge answers into data + for item in data: + item_id = item.get("id", "") + if item_id in answer_map: + item["ground_truth"] = answer_map[item_id] + + if limit: + data = data[:limit] + + print(f"Loaded {len(data)} examples for category: {category}") + return data + + +def fix_schema_types(schema: Any) -> Any: + """ + Fix Python type names to JSON Schema types. + BFCL uses Python type names like 'dict', 'list' but JSON Schema needs 'object', 'array'. + """ + if isinstance(schema, dict): + result = {} + for key, value in schema.items(): + if key == "type" and isinstance(value, str): + # Map Python types to JSON Schema types + type_mapping = { + "dict": "object", + "list": "array", + "str": "string", + "int": "integer", + "float": "number", + "bool": "boolean", + "NoneType": "null", + "tuple": "array", + } + result[key] = type_mapping.get(value, value) + else: + result[key] = fix_schema_types(value) + return result + elif isinstance(schema, list): + return [fix_schema_types(item) for item in schema] + else: + return schema + + +def convert_to_openai_tools(functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert BFCL function format to OpenAI tools format.""" + tools = [] + for func in functions: + if isinstance(func, str): + func = json.loads(func) + + # Fix the parameters schema to use valid JSON Schema types + parameters = fix_schema_types(func.get("parameters", {})) + + tool = { + "type": "function", + "function": { + "name": func.get("name", ""), + "description": func.get("description", ""), + "parameters": parameters, + }, + } + tools.append(tool) + return tools + + +def parse_function_call(response: str) -> List[Dict[str, Any]]: + """Parse function calls from model response.""" + calls = [] + + # Try to parse as JSON array + try: + parsed = json.loads(response) + if isinstance(parsed, list): + return parsed + elif isinstance(parsed, dict): + return [parsed] + except json.JSONDecodeError: + pass + + # Try to find function call patterns + # Pattern 1: function_name(args) + func_pattern = r"(\w+)\s*\((.*?)\)" + matches = re.findall(func_pattern, response, re.DOTALL) + for name, args_str in matches: + try: + # Try to parse args as Python dict/kwargs + args_str = args_str.strip() + if args_str: + # Convert 
to dict format + args = eval(f"dict({args_str})") + else: + args = {} + calls.append({"name": name, "arguments": args}) + except: + pass + + # Pattern 2: JSON-like tool_calls + tool_call_pattern = r'\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]+\})\}' + matches = re.findall(tool_call_pattern, response) + for name, args_str in matches: + try: + args = json.loads(args_str) + calls.append({"name": name, "arguments": args}) + except: + pass + + return calls + + +def extract_tool_calls_from_response(response) -> List[Dict[str, Any]]: + """Extract tool calls from OpenAI API response.""" + calls = [] + + if hasattr(response, "choices") and response.choices: + choice = response.choices[0] + message = choice.message + + # Check for tool_calls in response + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + func = tool_call.function + try: + args = json.loads(func.arguments) if func.arguments else {} + except json.JSONDecodeError: + args = {} + calls.append({"name": func.name, "arguments": args}) + + # Also check content for function calls (some models output in content) + if hasattr(message, "content") and message.content: + content_calls = parse_function_call(message.content) + if content_calls and not calls: + calls = content_calls + + return calls + + +def normalize_value(value: Any) -> Any: + """Normalize values for comparison.""" + if isinstance(value, str): + # Try to parse as number + try: + return float(value) + except ValueError: + return value.lower().strip() + elif isinstance(value, bool): + return value + elif isinstance(value, (int, float)): + return float(value) + elif isinstance(value, list): + return [normalize_value(v) for v in value] + elif isinstance(value, dict): + return {k: normalize_value(v) for k, v in value.items()} + return value + + +def value_matches_expected(predicted_value: Any, expected_values: Any) -> bool: + """ + Check if predicted value matches expected value(s). + BFCL format: expected values can be a list of acceptable values. 
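+
+    Illustrative examples (hypothetical values, not taken from the BFCL spec):
+        value_matches_expected(5, [5, "5"])                 -> True   (both normalize to 5.0)
+        value_matches_expected("NYC", ["New York", "NYC"])  -> True   (case-insensitive string match)
+        value_matches_expected("m", ["km"])                 -> False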
+ """ + # Normalize predicted value + pred_normalized = normalize_value(predicted_value) + + # If expected is a list, check if predicted matches any item + if isinstance(expected_values, list): + for exp_val in expected_values: + exp_normalized = normalize_value(exp_val) + if pred_normalized == exp_normalized: + return True + # Also try string comparison for edge cases + if str(pred_normalized) == str(exp_normalized): + return True + return False + else: + exp_normalized = normalize_value(expected_values) + return pred_normalized == exp_normalized or str(pred_normalized) == str(exp_normalized) + + +def compare_function_calls( + predicted: List[Dict[str, Any]], expected: List[Dict[str, Any]], strict: bool = False +) -> Tuple[bool, str]: + """Compare predicted function calls with expected ones.""" + if not predicted and not expected: + return True, "" + + if len(predicted) != len(expected): + return False, f"Count mismatch: predicted {len(predicted)}, expected {len(expected)}" + + # Sort by function name for comparison + pred_sorted = sorted(predicted, key=lambda x: x.get("name", "")) + exp_sorted = sorted(expected, key=lambda x: x.get("name", "")) + + for pred, exp in zip(pred_sorted, exp_sorted): + pred_name = pred.get("name", "") + exp_name = exp.get("name", "") + + if pred_name != exp_name: + return False, f"Function name mismatch: {pred_name} vs {exp_name}" + + pred_args = pred.get("arguments", {}) + exp_args = exp.get("arguments", {}) + + # Check required arguments match (BFCL format: values are lists of acceptable values) + for key, expected_values in exp_args.items(): + if key not in pred_args: + return False, f"Missing argument {key} in {pred_name}" + if not value_matches_expected(pred_args[key], expected_values): + return False, f"Argument {key} mismatch in {pred_name}" + + return True, "" + + +def parse_expected_output(ground_truth: Any) -> List[Dict[str, Any]]: + """ + Parse expected output from BFCL ground truth. 
+ + BFCL format: [{"func_name": {"arg1": [val1, val2], "arg2": [val3]}}] + Convert to: [{"name": "func_name", "arguments": {"arg1": [val1, val2], "arg2": [val3]}}] + """ + if isinstance(ground_truth, str): + try: + ground_truth = json.loads(ground_truth) + except json.JSONDecodeError: + # Try parsing as Python literal + try: + ground_truth = ast.literal_eval(ground_truth) + except: + return [] + + if not ground_truth: + return [] + + # Ensure it's a list + if isinstance(ground_truth, dict): + ground_truth = [ground_truth] + + result = [] + for item in ground_truth: + if isinstance(item, dict): + # Check if it's already in standard format {"name": ..., "arguments": ...} + if "name" in item and "arguments" in item: + result.append(item) + else: + # BFCL format: {"func_name": {"arg1": [v1], "arg2": [v2]}} + for func_name, args in item.items(): + if isinstance(args, dict): + result.append({"name": func_name, "arguments": args}) + else: + # Handle edge case where args might not be a dict + result.append({"name": func_name, "arguments": {}}) + + return result + + +class BFCLEvaluator: + """BFCL Benchmark Evaluator using OpenAI-compatible API.""" + + def __init__( + self, + base_url: str, + model_name: str, + api_key: str = "EMPTY", + max_tokens: int = 1024, + temperature: float = 0.0, + ): + self.client = OpenAI(base_url=base_url, api_key=api_key) + self.model_name = model_name + self.max_tokens = max_tokens + self.temperature = temperature + + def generate_response( + self, prompt: str, tools: List[Dict[str, Any]], system_prompt: Optional[str] = None + ) -> Tuple[Any, List[Dict[str, Any]]]: + """Generate response from the model with tool calling.""" + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + messages.append({"role": "user", "content": prompt}) + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + tools=tools if tools else None, + tool_choice="auto" if tools else None, + max_tokens=self.max_tokens, + temperature=self.temperature, + ) + tool_calls = extract_tool_calls_from_response(response) + return response, tool_calls + except Exception as e: + print(f"API Error: {e}") + return None, [] + + def evaluate_single(self, item: Dict[str, Any], category: str) -> EvalResult: + """Evaluate a single BFCL example.""" + task_id = item.get("id", "unknown") + + # Extract question and functions + question = item.get("question", [[{"role": "user", "content": ""}]]) + if isinstance(question, str): + prompt = question + elif isinstance(question, list) and question: + if isinstance(question[0], dict): + prompt = question[0].get("content", "") + elif isinstance(question[0], list) and question[0]: + prompt = question[0][0].get("content", "") + else: + prompt = str(question[0]) + else: + prompt = str(question) + + # Get functions + functions = item.get("function", []) + if isinstance(functions, str): + try: + functions = json.loads(functions) + except: + functions = [] + + if not isinstance(functions, list): + functions = [functions] + + # Convert to OpenAI tools format + tools = convert_to_openai_tools(functions) + + # Get expected output + ground_truth = item.get("ground_truth", item.get("answer", [])) + expected = parse_expected_output(ground_truth) + + # Generate response + system_prompt = ( + "You are a helpful assistant that can use tools/functions to help answer questions. " + "When you need to call a function, use the provided tools." 
+ ) + + response, predicted_calls = self.generate_response(prompt, tools, system_prompt) + + if response is None: + return EvalResult( + task_id=task_id, + category=category, + passed=False, + model_output="", + expected=expected, + error="API call failed", + ) + + # For irrelevance category, model should NOT call any function + if "irrelevance" in category.lower(): + passed = len(predicted_calls) == 0 + error = "Model called function when it shouldn't" if not passed else None + else: + # Compare function calls + passed, error = compare_function_calls(predicted_calls, expected) + + model_output = json.dumps(predicted_calls, indent=2) if predicted_calls else str(response) + + return EvalResult( + task_id=task_id, category=category, passed=passed, model_output=model_output, expected=expected, error=error + ) + + def evaluate_category(self, category: str, limit: Optional[int] = None, num_workers: int = 4) -> Dict[str, Any]: + """Evaluate all examples in a category.""" + print(f"\nLoading BFCL dataset for category: {category}") + data = load_bfcl_data(category, limit) + + if not data: + print(f"No data found for category: {category}") + return {"category": category, "total": 0, "passed": 0, "accuracy": 0.0} + + print(f"Loaded {len(data)} examples") + + results = [] + + # Use ThreadPoolExecutor for concurrent evaluation + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = {executor.submit(self.evaluate_single, item, category): item for item in data} + + for future in tqdm(as_completed(futures), total=len(futures), desc=f"Evaluating {category}"): + try: + result = future.result() + results.append(result) + except Exception as e: + print(f"Error evaluating: {e}") + + # Calculate metrics + total = len(results) + passed = sum(1 for r in results if r.passed) + accuracy = passed / total * 100 if total > 0 else 0.0 + + # Collect errors for analysis + errors = defaultdict(int) + for r in results: + if not r.passed and r.error: + errors[r.error[:50]] += 1 + + return { + "category": category, + "total": total, + "passed": passed, + "accuracy": accuracy, + "results": results, + "error_summary": dict(errors), + } + + +def main(): + parser = argparse.ArgumentParser(description="BFCL Evaluation for LightLLM") + parser.add_argument("--model_name", type=str, required=True, help="Model name") + parser.add_argument( + "--base_url", type=str, default="http://localhost:8000/v1", help="OpenAI-compatible API base URL" + ) + parser.add_argument("--api_key", type=str, default="EMPTY", help="API key (use EMPTY for local)") + parser.add_argument( + "--test_category", + type=str, + default="simple", + choices=list(TEST_CATEGORIES.keys()) + ["all"], + help="Test category to evaluate", + ) + parser.add_argument("--limit", type=int, default=None, help="Limit number of examples (for testing)") + parser.add_argument("--num_workers", type=int, default=4, help="Number of concurrent workers") + parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens to generate") + parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") + parser.add_argument("--output", "-o", type=str, default=None, help="Output file for detailed results") + + args = parser.parse_args() + + print("=" * 60) + print("BFCL (Berkeley Function Calling Leaderboard) Evaluation") + print("=" * 60) + print(f"Model: {args.model_name}") + print(f"API URL: {args.base_url}") + print(f"Test Category: {args.test_category}") + print() + + evaluator = BFCLEvaluator( + base_url=args.base_url, + 
model_name=args.model_name, + api_key=args.api_key, + max_tokens=args.max_tokens, + temperature=args.temperature, + ) + + # Determine categories to evaluate + if args.test_category == "all": + categories = list(TEST_CATEGORIES.keys()) + else: + categories = [args.test_category] + + all_results = {} + + for category in categories: + result = evaluator.evaluate_category(category, limit=args.limit, num_workers=args.num_workers) + all_results[category] = result + + print(f"\n{category.upper()} Results:") + print(f" Total: {result['total']}") + print(f" Passed: {result['passed']}") + print(f" Accuracy: {result['accuracy']:.2f}%") + + if result.get("error_summary"): + print(" Common errors:") + for error, count in sorted(result["error_summary"].items(), key=lambda x: -x[1])[:5]: + print(f" - {error}: {count}") + + # Print summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(f"{'Category':<25} {'Total':>8} {'Passed':>8} {'Accuracy':>10}") + print("-" * 60) + + total_all = 0 + passed_all = 0 + + for category, result in all_results.items(): + print(f"{category:<25} {result['total']:>8} {result['passed']:>8} {result['accuracy']:>9.2f}%") + total_all += result["total"] + passed_all += result["passed"] + + if len(all_results) > 1: + print("-" * 60) + overall_acc = passed_all / total_all * 100 if total_all > 0 else 0 + print(f"{'OVERALL':<25} {total_all:>8} {passed_all:>8} {overall_acc:>9.2f}%") + + print("=" * 60) + + # Save detailed results + if args.output: + output_data = { + "model": args.model_name, + "config": { + "base_url": args.base_url, + "max_tokens": args.max_tokens, + "temperature": args.temperature, + }, + "results": { + cat: { + "total": r["total"], + "passed": r["passed"], + "accuracy": r["accuracy"], + "error_summary": r.get("error_summary", {}), + } + for cat, r in all_results.items() + }, + } + with open(args.output, "w") as f: + json.dump(output_data, f, indent=2) + print(f"\nResults saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/test/eval/requirements.txt b/test/eval/requirements.txt new file mode 100644 index 000000000..e57d2da41 --- /dev/null +++ b/test/eval/requirements.txt @@ -0,0 +1,13 @@ +# Evaluation benchmark dependencies +aiohttp>=3.8.0 +tqdm>=4.64.0 +transformers>=4.30.0 +numpy>=1.21.0 +openai>=1.0.0 +huggingface_hub>=0.20.0 + +# Optional: official human-eval package for dataset loading +# pip install git+https://github.com/openai/human-eval.git + +# Optional: official BFCL evaluation package +# pip install bfcl-eval diff --git a/test/eval/run_bfcl.sh b/test/eval/run_bfcl.sh new file mode 100755 index 000000000..2e68f8380 --- /dev/null +++ b/test/eval/run_bfcl.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# BFCL (Berkeley Function Calling Leaderboard) evaluation script for LightLLM +# +# Prerequisites: +# 1. Start LightLLM server: +# python -m lightllm.server.api_server \ +# --model_dir /path/to/GLM-4.7-Flash \ +# --tp 1 \ +# --port 8000 +# +# 2. Install dependencies: +# pip install openai tqdm datasets + +set -e + +# Configuration +MODEL_NAME="${MODEL_NAME:-GLM-4.7-Flash}" +BASE_URL="${BASE_URL:-http://localhost:8000/v1}" +PORT="${PORT:-8000}" +TEST_CATEGORY="${TEST_CATEGORY:-simple}" +NUM_WORKERS="${NUM_WORKERS:-4}" + +# Check if server is running +if ! 
curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then + echo "Error: LightLLM server not running on port ${PORT}" + echo "Start the server first with:" + echo " python -m lightllm.server.api_server --model_dir /path/to/model --tp 1 --port ${PORT}" + exit 1 +fi + +echo "==========================================" +echo "BFCL Function Calling Evaluation" +echo "==========================================" +echo "Model: ${MODEL_NAME}" +echo "Server: ${BASE_URL}" +echo "Test Category: ${TEST_CATEGORY}" +echo "" + +# Run evaluation +python "$(dirname "$0")/eval_bfcl.py" \ + --model_name "${MODEL_NAME}" \ + --base_url "${BASE_URL}" \ + --test_category "${TEST_CATEGORY}" \ + --num_workers "${NUM_WORKERS}" \ + --output "bfcl_results_${TEST_CATEGORY}_$(date +%Y%m%d_%H%M%S).json" + +echo "" +echo "Evaluation complete!" From 7d9222cc97daf74c061aa3846da06d067333738e Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 28 Jan 2026 02:16:47 +0000 Subject: [PATCH 02/14] fix --- ...num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 110 ++++++++++++++++++ .../{topk_num=4}_NVIDIA_H200.json | 4 + ...orch.bfloat16,topk_num=4}_NVIDIA_H200.json | 6 + ...=10,dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ ...M=5,dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 110 ++++++++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 74 ++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 104 +++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 2 +- .../layer_weights/transformer_layer_weight.py | 2 +- lightllm/models/glm4_moe_lite/infer_struct.py | 6 - .../layer_weights/transformer_layer_weight.py | 6 - lightllm/models/glm4_moe_lite/model.py | 20 ---- lightllm/server/api_cli.py | 2 +- 18 files changed, 963 insertions(+), 35 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..bae1eb462 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..292e2e124 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..5ea445367 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..f8c805f76 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + 
"4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "400": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json index 0f0c175b9..1764ad455 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=4}_NVIDIA_H200.json @@ -23,6 +23,10 @@ "BLOCK_SIZE": 256, "num_warps": 8 }, + "16640": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, "2048": { "BLOCK_SIZE": 128, "num_warps": 4 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json index c6c3d54ff..8da1f07f1 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=4}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGE": 4, "num_warps": 4 }, + "16640": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "2048": { "BLOCK_DIM": 1024, "BLOCK_M": 1, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..41de2bf65 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_SEQ": 32, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 3, + "num_warps": 8 + }, + "100": { + 
"BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 1 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 5, + "num_warps": 2 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 4, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 2 + }, + "16384": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 2, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 1, + "num_warps": 2 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 5, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 5, + "num_warps": 1 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..feb169f56 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 1 + }, + "100": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 4 + }, + "1024": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 4, + "num_warps": 2 + }, + "128": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "16": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 16, + "num_stages": 2, + "num_warps": 1 + }, + "16640": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 3, + "num_warps": 1 + }, + "256": { + "BLOCK_SEQ": 2, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 2 + }, + "32": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 5, + "num_warps": 8 + }, + "8": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 8, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..c041200ea --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_M": 1, + 
"BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "4096": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..63ee27ea2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_M": 128, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "4": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "66560": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..340664b11 --- /dev/null +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_M": 8, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..78530fc09 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,104 @@ +{ + "1": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1024": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "400": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "65536": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index e1e435cce..023929b96 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -33,7 +33,7 @@ def __init__(self, layer_num, network_config): self.is_moe = ( network_config["n_routed_experts"] is not None and layer_num >= network_config["first_k_dense_replace"] - and layer_num % network_config["moe_layer_freq"] == 0 + and layer_num % network_config.get("moe_layer_freq", 1) == 0 ) self.n_shared_experts = network_config["n_shared_experts"] diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 783e70e64..96adc4b87 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -25,7 +25,7 @@ def _parse_config(self): self.is_moe = ( self.network_config_["n_routed_experts"] is not None and self.layer_num_ >= self.network_config_["first_k_dense_replace"] - and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0 + and self.layer_num_ % self.network_config_.get("moe_layer_freq", 1) == 0 ) self.tp_q_head_num_ = self.network_config_["num_attention_heads"] self.tp_q_head_num_ = self.tp_q_head_num_ // self.tp_world_size_ diff --git a/lightllm/models/glm4_moe_lite/infer_struct.py b/lightllm/models/glm4_moe_lite/infer_struct.py index 92c350abc..38d879d84 100644 --- a/lightllm/models/glm4_moe_lite/infer_struct.py +++ b/lightllm/models/glm4_moe_lite/infer_struct.py @@ -2,11 +2,5 @@ class Glm4MoeLiteInferStateInfo(Deepseek2InferStateInfo): - """Inference state for GLM-4.7-Flash (glm4_moe_lite architecture). - - Inherits from Deepseek2InferStateInfo as GLM-4.7-Flash uses the same - MLA (Multi-Head Latent Attention) mechanism as DeepSeek-V2/V3. - """ - def __init__(self): super().__init__() diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py index 2a10f6fdc..8a2196706 100644 --- a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -9,15 +9,12 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__(layer_num, data_type, network_config, quant_cfg) def _parse_config(self): - # Call parent's _parse_config to set n_embed, moe_inter, and other required attributes super()._parse_config() - # Override is_moe calculation for GLM4 (no moe_layer_freq check) self.is_moe = self.network_config_.get( "n_routed_experts" ) is not None and self.layer_num_ >= self.network_config_.get("first_k_dense_replace", 0) - # Override num_fused_shared_experts with GLM4-specific logic from lightllm.utils.envs_utils import get_env_start_args self.num_fused_shared_experts = 0 @@ -37,19 +34,16 @@ def load_hf_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] - # for quantized weights, dequantize first if self.quant_cfg.quantized_weight: kv_b_proj_ = weight_dequant( kv_b_proj_.cuda(), weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." 
+ weight_scale_suffix].cuda(), ).cpu() - # Use GLM4-specific split methods (different from DeepSeek2's dimensions) k_b_proj_ = self._load_kb(kv_b_proj_) v_b_proj_ = self._load_vb(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = v_b_proj_ - # rename the shared experts weight if self.num_fused_shared_experts > 0: self._rename_shared_experts(weights, weight_scale_suffix) diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a5970ab59..367928f2e 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -22,29 +22,9 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - - if "moe_layer_freq" not in self.config and self.config.get("n_routed_experts"): - self.config["moe_layer_freq"] = 1 - - if "routed_scaling_factor" not in self.config: - self.config["routed_scaling_factor"] = 1.8 - - if "topk_method" not in self.config: - self.config["topk_method"] = "noaux_tc" - if "scoring_func" not in self.config: self.config["scoring_func"] = "sigmoid" - logger.info( - f"GLM-4.7-Flash config: " - f"n_routed_experts={self.config.get('n_routed_experts')}, " - f"n_shared_experts={self.config.get('n_shared_experts')}, " - f"num_experts_per_tok={self.config.get('num_experts_per_tok')}, " - f"first_k_dense_replace={self.config.get('first_k_dense_replace')}, " - f"routed_scaling_factor={self.config.get('routed_scaling_factor')}, " - f"scoring_func={self.config.get('scoring_func')}" - ) - def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 1652318f6..5b69128b4 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -119,7 +119,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--batch_max_tokens", type=int, - default=16384, + default=None, help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", ) parser.add_argument( From 5ae7fb857fa374ece41f84f05e5f4658ea2630bd Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 28 Jan 2026 02:39:38 +0000 Subject: [PATCH 03/14] add config --- ...6,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...6,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...EAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ ...=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ ...N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 12 ++++++++++++ 5 files changed, 42 insertions(+) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index deb97363c..ea5ca845d 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -98,6 +98,15 @@ "num_stages": 3, "num_warps": 4 }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "8192": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json index a6c93c3f6..27b601a56 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -53,6 +53,15 @@ "num_stages": 3, "num_warps": 4 }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "2048": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json index 5601eab76..056b6a747 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "num_stages": 5, "num_warps": 1 }, + "16640": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 4, + "num_warps": 1 + }, "2048": { "BLOCK_SEQ": 1, "HEAD_PARALLEL_NUM": 2, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json index b82f25e17..70f6af6d6 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "2048": { "BLOCK_M": 32, "BLOCK_N": 256, diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json index ab4644621..c4ee7dd0f 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "2048": { "BLOCK_M": 8, "BLOCK_N": 256, @@ -89,6 +95,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "66560": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 64, From 9c9aa283dca5ed9d386bb4f82bc36b576bb91998 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 28 Jan 2026 03:07:00 +0000 Subject: [PATCH 04/14] fix mtp --- ...6,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...6,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...EAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ ...N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ ...{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 12 ++++++++++++ .../glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py | 8 ++++---- 6 files changed, 46 insertions(+), 4 deletions(-) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json index bae1eb462..f7df66542 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=1536,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -53,6 +53,15 @@ "num_stages": 3, "num_warps": 4 }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "2048": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index f8c805f76..063afa697 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=768,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -98,6 +98,15 @@ "num_stages": 3, "num_warps": 4 }, + "66560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "8192": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json index 41de2bf65..1b75882e7 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=10,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "num_stages": 2, "num_warps": 1 }, + "16640": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, "2048": { "BLOCK_SEQ": 1, "HEAD_PARALLEL_NUM": 2, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 340664b11..796d8f01b 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=5120,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "16640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, "2048": { "BLOCK_M": 32, "BLOCK_N": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 78530fc09..dfd9dd729 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "16640": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "2048": { "BLOCK_M": 8, "BLOCK_N": 128, @@ -89,6 +95,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "66560": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 128, diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py index 6994d21f7..1f2dc71d5 100644 --- 
a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py @@ -24,12 +24,12 @@ def _mtp_context_forward( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - layer_weight.enorm_weight_.rmsnorm_forward( + layer_weight.enorm_weight_( input=input_embdings, eps=self.eps_, out=input_embdings, ) - layer_weight.hnorm_weight_.rmsnorm_forward( + layer_weight.hnorm_weight_( input=tgt_embdings, eps=self.eps_, out=tgt_embdings, @@ -48,12 +48,12 @@ def _mtp_token_forward( tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - layer_weight.enorm_weight_.rmsnorm_forward( + layer_weight.enorm_weight_( input=input_embdings, eps=self.eps_, out=input_embdings, ) - layer_weight.hnorm_weight_.rmsnorm_forward( + layer_weight.hnorm_weight_( input=tgt_embdings, eps=self.eps_, out=tgt_embdings, From bb19ce4a7f5d7e81dc49b16c1ce5a8db7662d88f Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 28 Jan 2026 03:07:15 +0000 Subject: [PATCH 05/14] add docs --- docs/CN/source/cookbook/glm4_deployment.rst | 213 ++++++++++++++++++++ docs/CN/source/index.rst | 8 +- docs/EN/source/cookbook/glm4_deployment.rst | 213 ++++++++++++++++++++ docs/EN/source/index.rst | 8 +- 4 files changed, 440 insertions(+), 2 deletions(-) create mode 100644 docs/CN/source/cookbook/glm4_deployment.rst create mode 100644 docs/EN/source/cookbook/glm4_deployment.rst diff --git a/docs/CN/source/cookbook/glm4_deployment.rst b/docs/CN/source/cookbook/glm4_deployment.rst new file mode 100644 index 000000000..ec63490d4 --- /dev/null +++ b/docs/CN/source/cookbook/glm4_deployment.rst @@ -0,0 +1,213 @@ +.. _glm4_deployment: + +GLM-4.7-Flash 模型部署指南 +=========================== + +LightLLM 支持 GLM-4.7-Flash (glm4_moe_lite) 模型系列的部署,该模型采用 MoE 架构。本文档提供详细的部署配置、函数调用和 MTP(多令牌预测)支持信息。 + +模型概述 +-------- + +**主要特性:** + +- 分组 MoE,支持 top-k 专家选择 +- 支持 ``vanilla_with_att`` 和 ``eagle_with_att`` MTP 模式 +- 兼容 FlashAttention3 后端 +- 支持 XML 风格参数格式的函数调用 + +模型参考:https://huggingface.co/zai-org/GLM-4.7-Flash + +推荐启动脚本 (H200) +------------------- + +**基础启动命令:** + +.. code-block:: bash + + LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \ + python -m lightllm.server.api_server \ + --model_dir /path/to/GLM-4.7-Flash/ \ + --tp 1 \ + --max_req_total_len 202752 \ + --chunked_prefill_size 8192 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend flashinfer \ + --graph_max_batch_size 512 \ + --tool_call_parser glm47 \ + --reasoning_parser glm45 \ + --host 0.0.0.0 \ + --port 8000 + +**参数说明:** + +- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: 启用 Triton 自动调优以获得最佳内核性能 +- ``LOADWORKER=18``: 模型加载线程数,加快权重加载速度 +- ``--tp 1``: 张量并行度(单 GPU) +- ``--max_req_total_len 202752``: 最大请求总长度 +- ``--chunked_prefill_size 8192``: 预填充处理的分块大小 +- ``--llm_prefill_att_backend fa3``: 预填充阶段使用 FlashAttention3 +- ``--llm_decode_att_backend flashinfer``: 解码阶段使用 FlashInfer +- ``--graph_max_batch_size 512``: CUDA graph 最大批处理大小 +- ``--tool_call_parser glm47``: 使用 GLM-4.7 函数调用解析器 +- ``--reasoning_parser glm45``: 使用 GLM-4.5 推理解析器 + +MTP(多令牌预测)模式 +--------------------- + +要启用 MTP 进行推测解码,请添加以下参数: + +.. 
code-block:: bash + + LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \ + python -m lightllm.server.api_server \ + --model_dir /path/to/GLM-4.7-Flash/ \ + --tp 1 \ + --max_req_total_len 202752 \ + --chunked_prefill_size 8192 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend flashinfer \ + --graph_max_batch_size 512 \ + --tool_call_parser glm47 \ + --reasoning_parser glm45 \ + --mtp_step 4 \ + --mtp_mode eagle_with_att \ + --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \ + --host 0.0.0.0 \ + --port 8000 + +**MTP 参数说明:** + +- ``--mtp_step 4``: 每个 MTP 步骤预测的令牌数 +- ``--mtp_mode eagle_with_att``: MTP 模式(支持 ``vanilla_with_att`` 和 ``eagle_with_att``) +- ``--mtp_draft_model_dir``: MTP 草稿模型路径 + +函数调用支持 +------------ + +GLM-4.7-Flash 使用新的 ``Glm47Detector`` 类来解析 XML 风格的工具调用。 + +**函数调用格式:** + +.. code-block:: xml + + func_name + keyvalue + + +**特性:** + +- 完整的流式支持,支持增量解析 +- 兼容 OpenAI 风格的函数调用 API + +测试与验证 +---------- + +基础功能测试 +~~~~~~~~~~~~ + +.. code-block:: bash + + curl http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "inputs": "什么是人工智能?", + "parameters":{ + "max_new_tokens": 100, + "frequency_penalty": 1 + } + }' + +OpenAI 兼容聊天接口 +~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "GLM-4.7-Flash", + "messages": [{"role": "user", "content": "你好"}], + "max_tokens": 100 + }' + +性能基准测试 +------------ + +函数调用测试结果 (BFCL v3) +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 40 20 + + * - 类别 + - LightLLM + * - simple + - 62.50% + * - multiple + - 54.50% + * - parallel + - 69.50% + * - parallel_multiple + - 61.50% + * - java + - 66.00% + * - javascript + - 48.00% + * - irrelevance + - 83.33% + * - live_simple + - 45.74% + * - live_multiple + - 34.00% + * - live_parallel + - 25.00% + * - live_parallel_multiple + - 37.50% + * - rest + - 2.86% + * - sql + - 28.00% + * - **总体** + - **49.12%** + +速度测试结果 (ShareGPT 2000 条提示,4×H200) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 20 20 30 + + * - 工作负载 + - 输出 (tok/s) + - TTFT (ms) + - 端到端延迟 (ms) + * - burst + - 6442 + - 11476 + - 27719 + * - high-conc (512) + - **6728** + - 1099 + - 11240 + * - moderate (10 req/s) + - 1798 + - 196 + - 5746 + * - steady (5 req/s) + - 917 + - 154 + - 2797 + +硬件要求 +-------- + +**测试配置:** + +- 4× NVIDIA H200 (每卡 80GB HBM3) +- NVLink 4.0 互联 + +**最低要求:** + +- 基础部署需要单张 NVIDIA H100/H200 GPU(80GB 显存) +- 生产环境建议使用多 GPU 配置 diff --git a/docs/CN/source/index.rst b/docs/CN/source/index.rst index b97b2c759..58b747359 100755 --- a/docs/CN/source/index.rst +++ b/docs/CN/source/index.rst @@ -57,7 +57,13 @@ Lightllm 整合了众多的开源方案的优点,包括但不限于 FasterTran 思考解析(Reasoning Parser) APIServer 参数详解 lightllm api介绍 - + +.. toctree:: + :maxdepth: 1 + :caption: Cookbook + + GLM-4.7-Flash 部署 + .. toctree:: :maxdepth: 1 :caption: 模型支持 diff --git a/docs/EN/source/cookbook/glm4_deployment.rst b/docs/EN/source/cookbook/glm4_deployment.rst new file mode 100644 index 000000000..c3807f812 --- /dev/null +++ b/docs/EN/source/cookbook/glm4_deployment.rst @@ -0,0 +1,213 @@ +.. _glm4_deployment: + +GLM-4.7-Flash Model Deployment Guide +===================================== + +LightLLM supports deployment of GLM-4.7-Flash (glm4_moe_lite) model family with MoE architecture. This document provides detailed information on deployment configuration, function calling, and MTP (Multi-Token Prediction) support. 
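+
+As a quick preview of the function calling support described later in this guide,
+the request below sends an OpenAI-style ``tools`` definition to a server started
+with the launch command shown later on (``localhost:8000``). The tool name and
+schema here are illustrative placeholders only, not part of the model or server API:
+
+.. code-block:: bash
+
+    curl http://localhost:8000/v1/chat/completions \
+      -H "Content-Type: application/json" \
+      -d '{
+        "model": "GLM-4.7-Flash",
+        "messages": [{"role": "user", "content": "What is the weather in Beijing today?"}],
+        "tools": [{
+          "type": "function",
+          "function": {
+            "name": "get_weather",
+            "description": "Get the current weather for a given city",
+            "parameters": {
+              "type": "object",
+              "properties": {"city": {"type": "string"}},
+              "required": ["city"]
+            }
+          }
+        }]
+      }'
+
+Because the ``glm47`` tool call parser is enabled at launch, the model's XML-style
+tool call is converted back into the standard ``tool_calls`` field of the chat
+completion response.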
+ +Model Overview +-------------- + +**Key Features:** + +- Grouped MoE with top-k expert selection +- Support for ``vanilla_with_att`` and ``eagle_with_att`` MTP modes +- Compatible with FlashAttention3 backend +- Function calling support with XML-style argument format + +Model Reference: https://huggingface.co/zai-org/GLM-4.7-Flash + +Recommended Launch Script (H200) +-------------------------------- + +**Basic Launch Command:** + +.. code-block:: bash + + LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \ + python -m lightllm.server.api_server \ + --model_dir /path/to/GLM-4.7-Flash/ \ + --tp 1 \ + --max_req_total_len 202752 \ + --chunked_prefill_size 8192 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend flashinfer \ + --graph_max_batch_size 512 \ + --tool_call_parser glm47 \ + --reasoning_parser glm45 \ + --host 0.0.0.0 \ + --port 8000 + +**Parameter Description:** + +- ``LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1``: Enable Triton autotuning for optimal kernel performance +- ``LOADWORKER=18``: Number of model loading threads for faster weight loading +- ``--tp 1``: Tensor parallelism degree (single GPU) +- ``--max_req_total_len 202752``: Maximum total request length +- ``--chunked_prefill_size 8192``: Chunk size for prefill processing +- ``--llm_prefill_att_backend fa3``: Use FlashAttention3 for prefill +- ``--llm_decode_att_backend flashinfer``: Use FlashInfer for decode +- ``--graph_max_batch_size 512``: Maximum batch size for CUDA graph +- ``--tool_call_parser glm47``: Use GLM-4.7 function calling parser +- ``--reasoning_parser glm45``: Use GLM-4.5 reasoning parser + +MTP (Multi-Token Prediction) Mode +--------------------------------- + +To enable MTP for speculative decoding, add the following parameters: + +.. code-block:: bash + + LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 LOADWORKER=18 \ + python -m lightllm.server.api_server \ + --model_dir /path/to/GLM-4.7-Flash/ \ + --tp 1 \ + --max_req_total_len 202752 \ + --chunked_prefill_size 8192 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend flashinfer \ + --graph_max_batch_size 512 \ + --tool_call_parser glm47 \ + --reasoning_parser glm45 \ + --mtp_step 4 \ + --mtp_mode eagle_with_att \ + --mtp_draft_model_dir /path/to/GLM-4.7-Flash/ \ + --host 0.0.0.0 \ + --port 8000 + +**MTP Parameters:** + +- ``--mtp_step 4``: Number of tokens to predict in each MTP step +- ``--mtp_mode eagle_with_att``: MTP mode (supports ``vanilla_with_att`` and ``eagle_with_att``) +- ``--mtp_draft_model_dir``: Path to the draft model for MTP + +Function Calling Support +------------------------ + +GLM-4.7-Flash uses a new ``Glm47Detector`` class for parsing XML-style tool calls. + +**Function Call Format:** + +.. code-block:: xml + + func_name + keyvalue + + +**Features:** + +- Full streaming support for incremental parsing +- Compatible with OpenAI-style function calling API + +Testing and Validation +---------------------- + +Basic Functionality Testing +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + curl http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "inputs": "What is AI?", + "parameters":{ + "max_new_tokens": 100, + "frequency_penalty": 1 + } + }' + +OpenAI-Compatible Chat Completions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: bash + + curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "GLM-4.7-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 100 + }' + +Performance Benchmarks +---------------------- + +Function Calling Test Results (BFCL v3) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 40 20 + + * - Category + - LightLLM + * - simple + - 62.50% + * - multiple + - 54.50% + * - parallel + - 69.50% + * - parallel_multiple + - 61.50% + * - java + - 66.00% + * - javascript + - 48.00% + * - irrelevance + - 83.33% + * - live_simple + - 45.74% + * - live_multiple + - 34.00% + * - live_parallel + - 25.00% + * - live_parallel_multiple + - 37.50% + * - rest + - 2.86% + * - sql + - 28.00% + * - **OVERALL** + - **49.12%** + +Speed Test Results (ShareGPT 2000 prompts, 4×H200) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 30 20 20 30 + + * - Workload + - Output (tok/s) + - TTFT (ms) + - E2E Latency (ms) + * - burst + - 6442 + - 11476 + - 27719 + * - high-conc (512) + - **6728** + - 1099 + - 11240 + * - moderate (10 req/s) + - 1798 + - 196 + - 5746 + * - steady (5 req/s) + - 917 + - 154 + - 2797 + +Hardware Requirements +--------------------- + +**Tested Configuration:** + +- 4× NVIDIA H200 (80GB HBM3 each) +- NVLink 4.0 interconnect + +**Minimum Requirements:** + +- Single NVIDIA H100/H200 GPU with 80GB memory for basic deployment +- Multiple GPUs recommended for production workloads diff --git a/docs/EN/source/index.rst b/docs/EN/source/index.rst index 07eaaa42e..5ad3c63c1 100755 --- a/docs/EN/source/index.rst +++ b/docs/EN/source/index.rst @@ -56,7 +56,13 @@ Documentation List Reasoning Parser APIServer Parameters Lightllm API Introduction - + +.. toctree:: + :maxdepth: 1 + :caption: Cookbook + + GLM-4.7-Flash Deployment + .. 
toctree:: :maxdepth: 1 :caption: Model Support From 08b9c74617ff900fb6f84164af60d0926d962243 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 09:25:21 +0000 Subject: [PATCH 06/14] simplify weight --- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 62 +------------------ 2 files changed, 6 insertions(+), 60 deletions(-) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 96adc4b87..3eb09f917 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -65,7 +65,9 @@ def _init_weight(self): self._init_norm() def _split_kv_b_proj(self, kv_b_proj_): - kv_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank) + kv_b_proj_ = kv_b_proj_.view( + self.num_attention_heads, self.qk_nope_head_dim + self.v_head_dim, self.kv_lora_rank + ) k_b_proj_, v_b_proj_ = torch.split(kv_b_proj_, [self.qk_nope_head_dim, self.v_head_dim], dim=-2) # num_attention_heads x qk_nope_head_dim x kv_lora_rank k_b_proj_ = k_b_proj_.contiguous().to(kv_b_proj_.dtype) diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py index 8a2196706..3395c4e20 100644 --- a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -17,65 +17,9 @@ def _parse_config(self): from lightllm.utils.envs_utils import get_env_start_args - self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode == "TP" - self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) - - def load_hf_weights(self, weights): - from lightllm.common.basemodel import TransformerLayerWeight - from lightllm.models.deepseek2.triton_kernel.weight_dequant import weight_dequant - - kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") - weight_scale_suffix = None - if self.quant_cfg.quantized_weight: - weight_scale_suffix = kv_b_quant_method.weight_scale_suffix - - if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: - kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] - if self.quant_cfg.quantized_weight: - kv_b_proj_ = weight_dequant( - kv_b_proj_.cuda(), - weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." 
+ weight_scale_suffix].cuda(), - ).cpu() - k_b_proj_ = self._load_kb(kv_b_proj_) - v_b_proj_ = self._load_vb(kv_b_proj_) - weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ - weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = v_b_proj_ - - if self.num_fused_shared_experts > 0: - self._rename_shared_experts(weights, weight_scale_suffix) - - return TransformerLayerWeight.load_hf_weights(self, weights) - - def _load_kb(self, kv_b_proj_): - kv_dim = self.qk_nope_head_dim + self.v_head_dim - k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, kv_dim, self.kv_lora_rank)[:, : self.qk_nope_head_dim, :] - return k_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_kb_scale(self, kv_b_proj_, block_size): - kv_dim = self.qk_nope_head_dim + self.v_head_dim - k_b_proj_scale_ = kv_b_proj_.view( - self.num_attention_heads, kv_dim // block_size, self.kv_lora_rank // block_size - )[:, : self.qk_nope_head_dim // block_size, :] - return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb(self, kv_b_proj_): - kv_dim = self.qk_nope_head_dim + self.v_head_dim - v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, kv_dim)[ - :, :, self.qk_nope_head_dim : - ].transpose(0, 1) - return v_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb_scale(self, kv_b_proj_scale_, block_size): - kv_dim = self.qk_nope_head_dim + self.v_head_dim - v_b_proj_scale_ = kv_b_proj_scale_.T.view( - self.kv_lora_rank // block_size, - self.num_attention_heads, - kv_dim // block_size, - )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) - return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) + if get_env_start_args().enable_ep_moe and self.is_moe: + assert self.num_fused_shared_experts == 0, "n_shared_experts must be 0 when enable_ep_moe" def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] From 75f7859c135f82ae7ac946a79b5fa6c0c1b9cb79 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 09:26:30 +0000 Subject: [PATCH 07/14] move bfcl scripts --- test/{eval => acc/bfcl}/eval_bfcl.py | 0 test/{eval => acc/bfcl}/requirements.txt | 0 test/{eval => acc/bfcl}/run_bfcl.sh | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename test/{eval => acc/bfcl}/eval_bfcl.py (100%) mode change 100755 => 100644 rename test/{eval => acc/bfcl}/requirements.txt (100%) rename test/{eval => acc/bfcl}/run_bfcl.sh (100%) mode change 100755 => 100644 diff --git a/test/eval/eval_bfcl.py b/test/acc/bfcl/eval_bfcl.py old mode 100755 new mode 100644 similarity index 100% rename from test/eval/eval_bfcl.py rename to test/acc/bfcl/eval_bfcl.py diff --git a/test/eval/requirements.txt b/test/acc/bfcl/requirements.txt similarity index 100% rename from test/eval/requirements.txt rename to test/acc/bfcl/requirements.txt diff --git a/test/eval/run_bfcl.sh b/test/acc/bfcl/run_bfcl.sh old mode 100755 new mode 100644 similarity index 100% rename from test/eval/run_bfcl.sh rename to test/acc/bfcl/run_bfcl.sh From bb7a1b8dc31d9d941ebace6c452c701b08c5a98e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 09:27:46 +0000 Subject: [PATCH 08/14] fix --- .../layer_infer/transformer_layer_infer.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py 
b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py index bcea872fd..85aee2351 100644 --- a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -23,19 +23,6 @@ def is_moe(self): def is_moe(self, value): pass - def _bind_ffn(self): - if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) - self._tpsp_ffn = self._tpsp_ffn_ep - else: - self._ffn = partial(Glm4MoeLiteTransformerLayerInfer._moe_ffn, self) - self._tpsp_ffn = self._tpsp_ffn_tp - else: - self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) - self._tpsp_ffn = self._tpsp_ffn_tp - def _get_o(self, input: torch.Tensor, infer_state, layer_weight) -> torch.Tensor: if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) From e2a082a85e09099e2896c37ab8e875c10c5e3151 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 09:45:21 +0000 Subject: [PATCH 09/14] simply infer of glm4 --- .../layer_infer/transformer_layer_infer.py | 26 ++++---- .../layer_infer/transformer_layer_infer.py | 59 ------------------- 2 files changed, 16 insertions(+), 69 deletions(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 023929b96..98cc7c229 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -65,10 +65,10 @@ def _bind_ffn(self): if self.is_moe: enable_ep_moe = get_env_start_args().enable_ep_moe if enable_ep_moe: - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) + self._ffn = self._moe_ffn_edp self._tpsp_ffn = self._tpsp_ffn_ep else: - self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self) + self._ffn = self._moe_ffn self._tpsp_ffn = self._tpsp_ffn_tp else: self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) @@ -257,7 +257,7 @@ def _get_o( ) -> torch.Tensor: if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)) + o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) return o_tensor def _tpsp_get_o( @@ -269,7 +269,7 @@ def _tpsp_get_o( if input.shape[2] == self.kv_lora_rank: input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - input = input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim) + input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim) dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) @@ -302,7 +302,8 @@ def _moe_ffn( if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) - router_logits = layer_weight.moe_gate.mm(hidden_states) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ 
-327,7 +328,8 @@ def _moe_ffn_edp( if self.n_shared_experts is not None: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) - router_logits = layer_weight.moe_gate.mm(hidden_states) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + router_logits = layer_weight.moe_gate.mm(hidden_states.to(moe_gate_dtype)) ep_output = layer_weight.experts.experts( hidden_states, router_logits=router_logits, @@ -405,7 +407,8 @@ def overlap_tpsp_token_forward( input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - _0_router_logits = layer_weight.moe_gate.mm(_0_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype)) # 1 hook if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() @@ -439,7 +442,8 @@ def overlap_tpsp_token_forward( _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) # to do gate and disptatch - _1_router_logits = layer_weight.moe_gate.mm(_1_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype)) # 0 hook if getattr(infer_state, "hook", None) is not None: infer_state.hook() @@ -529,7 +533,8 @@ def overlap_tpsp_context_forward( input_embdings.add_(_0_o.view(-1, self.embed_dim_)) _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - _0_router_logits = layer_weight.moe_gate.mm(_0_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _0_router_logits = layer_weight.moe_gate.mm(_0_input1.to(moe_gate_dtype)) # wait last 1 combine if getattr(infer_state1, "hook", None) is not None: @@ -556,7 +561,8 @@ def overlap_tpsp_context_forward( _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight) # to do gate and disptatch - _1_router_logits = layer_weight.moe_gate.mm(_1_input1) + moe_gate_dtype = layer_weight.moe_gate.data_type_ + _1_router_logits = layer_weight.moe_gate.mm(_1_input1.to(moe_gate_dtype)) # 0 dispatch execute ( diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py index 85aee2351..95f41d995 100644 --- a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -22,62 +22,3 @@ def is_moe(self): @is_moe.setter def is_moe(self, value): pass - - def _get_o(self, input: torch.Tensor, infer_state, layer_weight) -> torch.Tensor: - if input.shape[2] == self.kv_lora_rank: - input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim)) - return o_tensor - - def _tpsp_get_o(self, input, infer_state, layer_weight) -> torch.Tensor: - if infer_state.need_dp_prefill_balance: - input = infer_state._all_to_all_balance_get(data=input) - - if input.shape[2] == self.kv_lora_rank: - input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) - - input = input.reshape(-1, self.tp_q_head_num_ * self.v_head_dim) - dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_ - o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device) - layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.input_ids), :]) - e_o_tensor = o_tensor[len(infer_state.input_ids) :, :] - if 
e_o_tensor.shape[0] > 0: - e_o_tensor.fill_(0) - - if self.tp_world_size_ > 1: - sp_token_num = o_tensor.shape[0] // self.tp_world_size_ - reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device) - reduce_scatter_tensor( - output=reduce_o_tensor, - input=o_tensor, - op=dist.ReduceOp.SUM, - group=infer_state.dist_group, - async_op=False, - ) - o_tensor = reduce_o_tensor - - return o_tensor - - def _moe_ffn(self, input, infer_state, layer_weight): - hidden_states = input.view(-1, self.embed_dim_) - num_tokens, hidden_dim = hidden_states.shape - - if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: - shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) - - router_logits = layer_weight.moe_gate.mm(hidden_states.to(torch.float32)) - - layer_weight.experts.experts( - hidden_states, - router_logits=router_logits, - top_k=self.num_experts_per_tok, - renormalize=self.norm_topk_prob, - use_grouped_topk=self.n_group, - topk_group=self.topk_group, - num_expert_group=self.n_group, - ) - - if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: - hidden_states.add_(shared_output) - - return hidden_states.view(num_tokens, hidden_dim) From e15186bd795e62c5ea5a3ca0666825bd04d1dc60 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 28 Jan 2026 10:40:22 +0000 Subject: [PATCH 10/14] fix --- ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 92 +++++++++++++++++++ ...num=5,use_fp8_w8a8=false}_NVIDIA_H200.json | 92 +++++++++++++++++++ .../{topk_num=5}_NVIDIA_H200.json | 46 ++++++++++ ...orch.bfloat16,topk_num=5}_NVIDIA_H200.json | 62 +++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 60 ++++++++++++ .../layer_weights/transformer_layer_weight.py | 7 +- 6 files changed, 356 insertions(+), 3 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..ecf276831 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,92 @@ +{ + "10240": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, 
+ "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "500": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "5120": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "83200": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..c938b85ef --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,92 @@ +{ + "100": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json new file mode 100644 index 000000000..77346239c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json @@ -0,0 +1,46 @@ +{ + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16640": { + "BLOCK_SIZE": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json new file mode 100644 index 000000000..ff4f95232 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json @@ -0,0 +1,62 @@ +{ + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16640": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json index c4ee7dd0f..0aa022ae3 100644 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -17,18 +17,36 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "10240": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "128": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 1, "num_warps": 4 }, + "1280": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "16": { "BLOCK_M": 1, "BLOCK_N": 128, "NUM_STAGES": 4, "num_warps": 4 }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "16384": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -47,6 +65,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "20480": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "256": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -59,6 +83,12 @@ "NUM_STAGES": 4, "num_warps": 8 }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "4": { "BLOCK_M": 1, "BLOCK_N": 32, @@ -77,18 +107,36 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "500": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, "512": { "BLOCK_M": 8, "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 4 }, + "5120": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "64": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 1, "num_warps": 1 }, + "640": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "65536": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -107,10 +155,22 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "8192": { "BLOCK_M": 8, "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 1 + }, + "83200": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py index 3395c4e20..848c2df84 100644 --- a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -17,9 +17,10 @@ def _parse_config(self): from lightllm.utils.envs_utils import get_env_start_args - self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) - if get_env_start_args().enable_ep_moe and self.is_moe: - assert self.num_fused_shared_experts == 0, "n_shared_experts must be 0 when enable_ep_moe" + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." 
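+            # Fold the shared experts into the routed expert weight list so they are
+            # loaded and computed together with the routed experts (TP mode only).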
+ self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] From 43bbc297dce4dcd7aa6ce7342a32ed044205c787 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 29 Jan 2026 03:36:59 +0000 Subject: [PATCH 11/14] add config --- ..._num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 18 ++++++++++++++++++ ..._num=5,use_fp8_w8a8=false}_NVIDIA_H200.json | 18 ++++++++++++++++++ .../{topk_num=5}_NVIDIA_H200.json | 4 ++++ ...torch.bfloat16,topk_num=5}_NVIDIA_H200.json | 12 ++++++++++++ ...,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 12 ++++++++++++ 5 files changed, 64 insertions(+) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index ecf276831..588bc5a2a 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=1536,N=2048,expert_num=65,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -44,6 +44,24 @@ "num_stages": 2, "num_warps": 4 }, + "40": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "5": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, "500": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 64, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json index c938b85ef..5f05eb2a6 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=3072,expert_num=65,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=5,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -1,4 +1,13 @@ { + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "100": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 64, @@ -88,5 +97,14 @@ "NEED_TRANS": false, "num_stages": 2, "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 } } \ No newline at end of file diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json index 77346239c..f1d903fd2 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=5}_NVIDIA_H200.json @@ -1,4 +1,8 @@ { + "1": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, "100": { "BLOCK_SIZE": 128, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json index ff4f95232..610acfaa5 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=5}_NVIDIA_H200.json @@ -1,4 +1,10 @@ { + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 16, + "NUM_STAGE": 4, + "num_warps": 4 + }, "100": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -58,5 +64,11 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 0aa022ae3..71bc9d341 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -95,6 +95,12 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "40": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "400": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -107,6 +113,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "5": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 8 + }, "500": { "BLOCK_M": 1, "BLOCK_N": 256, From 22e4a5347a1dc5ebcdf6a2aed0a47f218686ccf7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 29 Jan 2026 04:33:26 +0000 Subject: [PATCH 12/14] clean code --- lightllm/common/basemodel/attention/flashinfer/mla.py | 3 +-- lightllm/common/basemodel/attention/triton/mla.py | 1 - ...oat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...oat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 9 +++++++++ ...1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ .../{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ .../{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ lightllm/common/triton_utils/autotuner.py | 2 +- lightllm/models/deepseek2/model.py | 1 - .../glm4_moe_lite/layer_infer/transformer_layer_infer.py | 7 ------- 
.../layer_weights/transformer_layer_weight.py | 1 - 11 files changed, 38 insertions(+), 13 deletions(-) diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 537dbee22..c9e831829 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -71,7 +71,7 @@ def init_state(self): num_qo_heads=self.backend.tp_q_head_num, num_kv_heads=self.backend.tp_q_head_num, head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, - head_dim_vo=self.backend.v_head_dim, # Use v_head_dim, not qk_nope_head_dim + head_dim_vo=self.backend.v_head_dim, q_data_type=self.backend.q_data_type, causal=True, sm_scale=self.backend.softmax_scale, @@ -103,7 +103,6 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - # Output dimension is v_head_dim (from v.shape[-1]), not qk_nope_head_dim o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py index fbdae4012..c7edecd10 100644 --- a/lightllm/common/basemodel/attention/triton/mla.py +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -44,7 +44,6 @@ def _mla_prefill_att( qk_rope_head_dim = 64 q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] - # GLM-4.7-Flash : v_head_dim != qk_nope_head_dim o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device) k_nope, k_rope = k assert att_control.mla_prefill diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json index 292e2e124..72dc716ac 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=768,expert_num=64,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=4,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -44,6 +44,15 @@ "num_stages": 4, "num_warps": 4 }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "16640": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index 5ea445367..ceba290e6 100644 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=384,N=2048,expert_num=64,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -89,6 +89,15 @@ "num_stages": 3, "num_warps": 4 }, + "65536": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "66560": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 128, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json index feb169f56..0c480f44e 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=5,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -29,6 +29,12 @@ "num_stages": 2, "num_warps": 1 }, + "16384": { + "BLOCK_SEQ": 4, + "HEAD_PARALLEL_NUM": 1, + "num_stages": 2, + "num_warps": 1 + }, "16640": { "BLOCK_SEQ": 4, "HEAD_PARALLEL_NUM": 1, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json index c041200ea..1bb9f7370 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=2560,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -29,6 +29,12 @@ "NUM_STAGES": 2, "num_warps": 8 }, + "16384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "16640": { "BLOCK_M": 1, "BLOCK_N": 256, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 63ee27ea2..f133810ae 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=384,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -89,6 +89,12 @@ "NUM_STAGES": 4, "num_warps": 8 }, + "65536": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "66560": { "BLOCK_M": 64, "BLOCK_N": 128, diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..c62a2572f 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ 
b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. """ - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index c7f13bb63..e596eed97 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -43,7 +43,6 @@ def _init_some_value(self): self.qk_rope_head_dim = self.config["qk_rope_head_dim"] self.q_lora_rank = self.config["q_lora_rank"] self.kv_lora_rank = self.config["kv_lora_rank"] - # v_head_dim defaults to qk_nope_head_dim for DeepSeek-V2, but GLM-4.7-Flash has different value self.v_head_dim = self.config.get("v_head_dim", self.qk_nope_head_dim) self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py index 95f41d995..6d0041014 100644 --- a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -1,11 +1,4 @@ -import os -import torch -import torch.distributed as dist -import triton -from functools import partial from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.distributed.communication_op import reduce_scatter_tensor class Glm4MoeLiteTransformerLayerInfer(Deepseek2TransformerLayerInfer): diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py index 848c2df84..92a5eabe0 100644 --- a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -1,4 +1,3 @@ -import os import torch from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight From 5f08bb7a5522ff3b9a0baa830d80e87918e55bf4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 29 Jan 2026 09:13:08 +0000 Subject: [PATCH 13/14] clean code --- .../basemodel/attention/flashinfer/mla.py | 3 +- .../triton_kernel/fused_moe/grouped_topk.py | 2 +- .../triton_kernel/fused_moe/topk_select.py | 8 +- .../context_flashattention_nopad_with_v.py | 46 ++--- ...=20,dtype=torch.bfloat16}_NVIDIA_H200.json | 6 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 186 ++++++++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 96 +++++++++ .../pre_and_post_layer_weight.py | 12 +- .../layer_infer/transformer_layer_infer.py | 11 -- .../layer_weights/transformer_layer_weight.py | 47 ----- .../glm4_moe_lite_mtp/layer_infer/__init__.py | 3 - .../layer_infer/pre_layer_infer.py | 82 -------- .../pre_and_post_layer_weight.py | 50 +---- lightllm/models/glm4_moe_lite_mtp/model.py | 4 +- 15 files changed, 325 insertions(+), 237 deletions(-) delete mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py delete mode 100644 lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 
c9e831829..84b44dc45 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -16,8 +16,7 @@ def __init__(self, model): self.qk_nope_head_dim = model.qk_nope_head_dim self.qk_rope_head_dim = model.qk_rope_head_dim self.kv_lora_rank = model.kv_lora_rank - # v_head_dim may differ from qk_nope_head_dim (e.g., GLM-4.7-Flash: v_head_dim=256, qk_nope_head_dim=192) - self.v_head_dim = getattr(model, "v_head_dim", self.qk_nope_head_dim) + self.v_head_dim = model.v_head_dim self.q_data_type = model.data_type self.kv_data_type = model.data_type self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py index 2687adf14..fb0323cd4 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py @@ -227,7 +227,7 @@ def triton_grouped_topk( scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda") + out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 7ac5a03b5..72c3a381e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -196,12 +196,10 @@ def select_experts( scoring_func=scoring_func, ) else: - if correction_bias is not None: + group_score_topk_num = 1 + # for deepseek v3 + if topk_group == 4 and num_expert_group == 8 and top_k == 8: group_score_topk_num = 2 - elif topk_group == 4 and num_expert_group == 8 and top_k == 8: - group_score_topk_num = 2 - else: - group_score_topk_num = 1 topk_weights, topk_ids = triton_grouped_topk( hidden_states=hidden_states, diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py index d79020844..be0635182 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py @@ -36,9 +36,6 @@ def _fwd_kernel_with_v( BLOCK_DMODEL: tl.constexpr, BLOCK_ROPE_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_V_DMODEL: tl.constexpr, - ACTUAL_DMODEL: tl.constexpr, - ACTUAL_V_DMODEL: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -56,13 +53,8 @@ def _fwd_kernel_with_v( # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) - offs_v_d = tl.arange(0, BLOCK_V_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - d_mask = offs_d < ACTUAL_DMODEL - v_d_mask = offs_v_d < ACTUAL_V_DMODEL - off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] off_q_rope = ( (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs @@ -71,10 +63,9 @@ def 
_fwd_kernel_with_v( ) off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] - off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_v_d[None, :] + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] - q_mask = (offs_m[:, None] < cur_batch_seq_len) & d_mask[None, :] - q = tl.load(Q_nope + off_q, mask=q_mask, other=0.0) + q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K_nope + off_k @@ -84,7 +75,7 @@ def _fwd_kernel_with_v( # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_V_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) @@ -92,16 +83,14 @@ def _fwd_kernel_with_v( for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k_seq_mask = (start_n + offs_n[None, :]) < block_end_loc - k_mask = k_seq_mask & d_mask[:, None] k = tl.load( k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, - mask=k_mask, + mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0, ) k_rope = tl.load( k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, - mask=k_seq_mask, + mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0, ) @@ -123,11 +112,9 @@ def _fwd_kernel_with_v( # -- update output accumulator -- acc = acc * alpha[:, None] # update acc - v_seq_mask = (start_n + offs_n[:, None]) < block_end_loc - v_mask = v_seq_mask & v_d_mask[None, :] v = tl.load( v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=v_mask, + mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0, ) p = p.to(v.dtype) @@ -137,10 +124,9 @@ def _fwd_kernel_with_v( acc = acc / l_i[:, None] # initialize pointers to output - off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_v_d[None, :] + off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] out_ptrs = Out + off_o - o_mask = (offs_m[:, None] < cur_batch_seq_len) & v_d_mask[None, :] - tl.store(out_ptrs, acc, mask=o_mask) + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @@ -163,14 +149,13 @@ def context_attention_fwd_with_v( BLOCK = 128 if not is_tesla() else 64 q_nope_dim = q_nope.shape[-1] q_rope_dim = q_rope.shape[-1] - v_dim = v.shape[-1] assert q_nope_dim == k_nope.shape[-1] assert q_rope_dim == k_rope.shape[-1] + assert q_nope_dim in {16, 32, 64, 128, 256, 512} + assert q_rope_dim in {16, 32, 64, 128, 256} + assert q_nope_dim == v.shape[-1] - q_nope_dim_padded = triton.next_power_of_2(q_nope_dim) - v_dim_padded = triton.next_power_of_2(v_dim) - - if q_nope_dim_padded >= 512 or v_dim_padded >= 512: + if q_nope_dim >= 512: BLOCK = 64 if not is_tesla() else 32 else: BLOCK = 128 if not is_tesla() else 64 @@ -182,7 +167,7 @@ def context_attention_fwd_with_v( batch, head = b_seq_len.shape[0], q_nope.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - num_warps = 4 if q_nope_dim_padded <= 64 else 8 + num_warps = 4 if 
q_nope_dim <= 64 else 8 _fwd_kernel_with_v[grid]( q_nope, @@ -209,12 +194,9 @@ def context_attention_fwd_with_v( o.stride(1), b_prompt_cache_len=b_prompt_cache_len, BLOCK_M=BLOCK, - BLOCK_DMODEL=q_nope_dim_padded, + BLOCK_DMODEL=q_nope_dim, BLOCK_ROPE_DMODEL=q_rope_dim, BLOCK_N=BLOCK, - BLOCK_V_DMODEL=v_dim_padded, - ACTUAL_DMODEL=q_nope_dim, - ACTUAL_V_DMODEL=v_dim, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json index 056b6a747..0dcae9848 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/rotary_emb_fwd:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=20,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "HEAD_PARALLEL_NUM": 16, "num_stages": 4, "num_warps": 1 + }, + "8192": { + "BLOCK_SEQ": 1, + "HEAD_PARALLEL_NUM": 2, + "num_stages": 3, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 70f6af6d6..7e9c73e72 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=10240,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -76,5 +76,11 @@ "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 71bc9d341..ded43f7a5 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=1536,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -35,6 +35,12 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "1536": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "16": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -59,6 +65,42 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "17280": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "17536": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1792": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18304": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "18688": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + 
"1920": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, "2048": { "BLOCK_M": 8, "BLOCK_N": 256, @@ -77,6 +119,18 @@ "NUM_STAGES": 2, "num_warps": 4 }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2816": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "32": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -89,6 +143,36 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "3328": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "34560": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "35200": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "3584": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "4": { "BLOCK_M": 1, "BLOCK_N": 32, @@ -113,6 +197,24 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "4608": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "4864": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "4992": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "5": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -137,6 +239,18 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "5504": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "5632": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "64": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -149,18 +263,54 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "6400": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "65536": { "BLOCK_M": 1, "BLOCK_N": 128, "NUM_STAGES": 4, "num_warps": 1 }, + "65920": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "66048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "66560": { "BLOCK_M": 1, "BLOCK_N": 128, "NUM_STAGES": 4, "num_warps": 1 }, + "66944": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67712": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "67968": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 64, @@ -184,5 +334,41 @@ "BLOCK_N": 128, "NUM_STAGES": 4, "num_warps": 1 + }, + "8832": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "896": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9088": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9216": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9856": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "9984": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json index dfd9dd729..4655cbf30 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=768,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -17,12 +17,36 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "11776": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "128": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 2, "num_warps": 1 }, + "12800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "135040": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "135168": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "16": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -41,12 +65,24 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "19584": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "2048": { "BLOCK_M": 8, "BLOCK_N": 128, "NUM_STAGES": 4, "num_warps": 1 }, + "20992": { + "BLOCK_M": 32, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "256": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -59,6 +95,36 @@ "NUM_STAGES": 1, "num_warps": 1 }, + "3200": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "3584": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "36224": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "36608": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, "4": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -83,12 +149,30 @@ "NUM_STAGES": 1, "num_warps": 1 }, + "5248": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "5504": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "64": { "BLOCK_M": 1, "BLOCK_N": 256, "NUM_STAGES": 1, "num_warps": 1 }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "65536": { "BLOCK_M": 8, "BLOCK_N": 128, @@ -101,6 +185,18 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "6912": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "7808": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 128, diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 91c0b2b3f..596fe3e73 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -9,32 +9,32 @@ class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, quant_cfg: Quantcfg): + def __init__(self, data_type, network_config, quant_cfg: Quantcfg, layer_idx=0): super().__init__(data_type, network_config) self.quant_cfg: Quantcfg = quant_cfg hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( in_dim=hidden_size * 2, out_dims=[hidden_size], - weight_names="model.layers.0.eh_proj.weight", + weight_names=f"model.layers.{layer_idx}.eh_proj.weight", data_type=self.data_type_, - quant_method=self.quant_cfg.get_quant_method(0, "eh_proj"), + quant_method=self.quant_cfg.get_quant_method(layer_idx, "eh_proj"), tp_rank=0, tp_world_size=1, ) self.enorm_weight_ = RMSNormWeight( dim=hidden_size, - weight_name="model.layers.0.enorm.weight", + weight_name=f"model.layers.{layer_idx}.enorm.weight", data_type=self.data_type_, ) self.hnorm_weight_ = RMSNormWeight( 
dim=hidden_size, - weight_name="model.layers.0.hnorm.weight", + weight_name=f"model.layers.{layer_idx}.hnorm.weight", data_type=self.data_type_, ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, - weight_name="model.layers.0.shared_head.norm.weight", + weight_name=f"model.layers.{layer_idx}.shared_head.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py index 6d0041014..5c53318f6 100644 --- a/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/glm4_moe_lite/layer_infer/transformer_layer_infer.py @@ -3,15 +3,4 @@ class Glm4MoeLiteTransformerLayerInfer(Deepseek2TransformerLayerInfer): def __init__(self, layer_num, network_config): - self._glm4_layer_num = layer_num - self._glm4_first_k_dense = network_config.get("first_k_dense_replace", 0) - self._glm4_has_routed_experts = network_config.get("n_routed_experts") is not None super().__init__(layer_num, network_config) - - @property - def is_moe(self): - return self._glm4_has_routed_experts and self._glm4_layer_num >= self._glm4_first_k_dense - - @is_moe.setter - def is_moe(self, value): - pass diff --git a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py index 92a5eabe0..e37b1b325 100644 --- a/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/glm4_moe_lite/layer_weights/transformer_layer_weight.py @@ -6,50 +6,3 @@ class Glm4MoeLiteTransformerLayerWeight(Deepseek2TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__(layer_num, data_type, network_config, quant_cfg) - - def _parse_config(self): - super()._parse_config() - - self.is_moe = self.network_config_.get( - "n_routed_experts" - ) is not None and self.layer_num_ >= self.network_config_.get("first_k_dense_replace", 0) - - from lightllm.utils.envs_utils import get_env_start_args - - self.num_fused_shared_experts = 0 - if get_env_start_args().enable_fused_shared_experts and self.is_moe: - assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can only work with tp mode." 
- self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) - - def _init_moe(self): - moe_intermediate_size = self.network_config_["moe_intermediate_size"] - hidden_size = self.network_config_["hidden_size"] - - self.moe_gate = ROWMMWeight( - in_dim=hidden_size, - out_dims=[self.n_routed_experts], - weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", - data_type=torch.float32, # Router gate needs float32 for numerical stability - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - - if self.num_fused_shared_experts == 0: - self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) - - self.experts = FusedMoeWeight( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name=self.e_score_correction_bias_name, - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - hidden_size=hidden_size, - moe_intermediate_size=moe_intermediate_size, - data_type=self.data_type_, - quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), - num_fused_shared_experts=self.num_fused_shared_experts, - layer_num=self.layer_num_, - network_config=self.network_config_, - ) diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py deleted file mode 100644 index e357bfa19..000000000 --- a/lightllm/models/glm4_moe_lite_mtp/layer_infer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer - -__all__ = ["Glm4MoeLiteMTPPreLayerInfer"] diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py deleted file mode 100644 index 1f2dc71d5..000000000 --- a/lightllm/models/glm4_moe_lite_mtp/layer_infer/pre_layer_infer.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch - -from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( - Glm4MoeLiteMTPPreAndPostLayerWeight, -) -from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer - - -class Glm4MoeLiteMTPPreLayerInfer(LlamaPreLayerInfer): - def __init__(self, network_config): - super().__init__(network_config) - self.eps_ = network_config["rms_norm_eps"] - self.hidden_size = network_config["hidden_size"] - - def _mtp_context_forward( - self, - input_embdings, - infer_state: Glm4MoeLiteInferStateInfo, - layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, - ): - tgt_embdings = infer_state.mtp_draft_input_hiddens - assert ( - input_embdings.shape[0] == tgt_embdings.shape[0] - ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - - layer_weight.enorm_weight_( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) - layer_weight.hnorm_weight_( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) - cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - - ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) - return ans_logics - - def _mtp_token_forward( - self, - input_embdings, - infer_state: Glm4MoeLiteInferStateInfo, - layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, - ): - tgt_embdings = infer_state.mtp_draft_input_hiddens - assert input_embdings.shape[0] == tgt_embdings.shape[0] - - layer_weight.enorm_weight_( - input=input_embdings, - eps=self.eps_, - 
out=input_embdings, - ) - layer_weight.hnorm_weight_( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) - cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) - - ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) - return ans_logics - - def context_forward( - self, - input_ids, - infer_state: Glm4MoeLiteInferStateInfo, - layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, - ): - input_embdings = super().context_forward(input_ids, infer_state, layer_weight) - return self._mtp_context_forward(input_embdings, infer_state, layer_weight) - - def token_forward( - self, - input_ids, - infer_state: Glm4MoeLiteInferStateInfo, - layer_weight: Glm4MoeLiteMTPPreAndPostLayerWeight, - ): - input_embdings = super().token_forward(input_ids, infer_state, layer_weight) - return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py index 78b4d411c..b1e7354b4 100644 --- a/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/glm4_moe_lite_mtp/layer_weights/pre_and_post_layer_weight.py @@ -1,48 +1,6 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ( - EmbeddingWeight, - LMHeadWeight, - RMSNormWeight, - ROWMMWeight, -) -from lightllm.common.quantization import Quantcfg +from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight -class Glm4MoeLiteMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, quant_cfg: Quantcfg): - super().__init__(data_type, network_config) - self.quant_cfg: Quantcfg = quant_cfg - - mtp_layer_idx = network_config["num_hidden_layers"] - hidden_size = network_config["hidden_size"] - - self.eh_proj_weight_ = ROWMMWeight( - in_dim=hidden_size * 2, - out_dims=[hidden_size], - weight_names=f"model.layers.{mtp_layer_idx}.eh_proj.weight", - data_type=self.data_type_, - quant_method=self.quant_cfg.get_quant_method(mtp_layer_idx, "eh_proj"), - tp_rank=0, - tp_world_size=1, - ) - - self.enorm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=f"model.layers.{mtp_layer_idx}.enorm.weight", - data_type=self.data_type_, - ) - - self.hnorm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=f"model.layers.{mtp_layer_idx}.hnorm.weight", - data_type=self.data_type_, - ) - - self.final_norm_weight_ = RMSNormWeight( - dim=hidden_size, - weight_name=f"model.layers.{mtp_layer_idx}.shared_head.norm.weight", - data_type=self.data_type_, - ) - - self.wte_weight_: EmbeddingWeight = None - self.lm_head_weight_: LMHeadWeight = None +class Glm4MoeLiteMTPPreAndPostLayerWeight(Deepseek3MTPPreAndPostLayerWeight): + def __init__(self, data_type, network_config, quant_cfg, layer_idx=0): + super().__init__(data_type, network_config, quant_cfg, network_config["num_hidden_layers"]) diff --git a/lightllm/models/glm4_moe_lite_mtp/model.py b/lightllm/models/glm4_moe_lite_mtp/model.py index dea415abb..549bf7ce4 100644 --- a/lightllm/models/glm4_moe_lite_mtp/model.py +++ b/lightllm/models/glm4_moe_lite_mtp/model.py @@ -1,6 +1,6 @@ from typing import List +from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel -from 
lightllm.models.glm4_moe_lite_mtp.layer_infer.pre_layer_infer import Glm4MoeLiteMTPPreLayerInfer from lightllm.models.glm4_moe_lite_mtp.layer_weights.pre_and_post_layer_weight import ( Glm4MoeLiteMTPPreAndPostLayerWeight, ) @@ -11,7 +11,7 @@ class Glm4MoeLiteMTPModel(Glm4MoeLiteTpPartModel): pre_and_post_weight_class = Glm4MoeLiteMTPPreAndPostLayerWeight - pre_layer_infer_class = Glm4MoeLiteMTPPreLayerInfer + pre_layer_infer_class = Deepseek3MTPPreLayerInfer def __init__(self, kvargs: dict): self._pre_init(kvargs) From 586a62bcdf9fe1eb45856b1d664faf6b4283340a Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 29 Jan 2026 11:56:14 +0000 Subject: [PATCH 14/14] add GlmAttBackend --- ...,q_rope_dim=64,v_dim=256}_NVIDIA_H200.json | 62 +++++ lightllm/models/glm4_moe_lite/model.py | 27 +- .../glm4_moe_lite/triton_kernel/__init__.py | 1 + .../context_flashattention_nopad.py | 247 ++++++++++++++++++ .../glm4_moe_lite/triton_kernel/mla_att.py | 134 ++++++++++ 5 files changed, 468 insertions(+), 3 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/glm_context_attention_fwd_with_v:v1/{dtype=torch.bfloat16,num_heads=20,q_nope_dim=192,q_rope_dim=64,v_dim=256}_NVIDIA_H200.json create mode 100644 lightllm/models/glm4_moe_lite/triton_kernel/__init__.py create mode 100644 lightllm/models/glm4_moe_lite/triton_kernel/context_flashattention_nopad.py create mode 100644 lightllm/models/glm4_moe_lite/triton_kernel/mla_att.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/glm_context_attention_fwd_with_v:v1/{dtype=torch.bfloat16,num_heads=20,q_nope_dim=192,q_rope_dim=64,v_dim=256}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/glm_context_attention_fwd_with_v:v1/{dtype=torch.bfloat16,num_heads=20,q_nope_dim=192,q_rope_dim=64,v_dim=256}_NVIDIA_H200.json new file mode 100644 index 000000000..5f9f3e4fc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/glm_context_attention_fwd_with_v:v1/{dtype=torch.bfloat16,num_heads=20,q_nope_dim=192,q_rope_dim=64,v_dim=256}_NVIDIA_H200.json @@ -0,0 +1,62 @@ +{ + "1": { + "BLOCK": 32, + "num_stages": 2, + "num_warps": 4 + }, + "100": { + "BLOCK": 64, + "num_stages": 2, + "num_warps": 8 + }, + "1024": { + "BLOCK": 64, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK": 64, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK": 32, + "num_stages": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK": 64, + "num_stages": 1, + "num_warps": 4 + }, + "2048": { + "BLOCK": 64, + "num_stages": 1, + "num_warps": 4 + }, + "256": { + "BLOCK": 64, + "num_stages": 2, + "num_warps": 8 + }, + "32": { + "BLOCK": 32, + "num_stages": 2, + "num_warps": 4 + }, + "4096": { + "BLOCK": 64, + "num_stages": 1, + "num_warps": 4 + }, + "64": { + "BLOCK": 64, + "num_stages": 2, + "num_warps": 8 + }, + "8": { + "BLOCK": 32, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index 367928f2e..00f87ae8d 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -5,9 +5,7 @@ from lightllm.models.glm4_moe_lite.layer_weights.transformer_layer_weight import Glm4MoeLiteTransformerLayerWeight from lightllm.models.glm4_moe_lite.infer_struct import Glm4MoeLiteInferStateInfo from lightllm.distributed.communication_op import dist_group_manager -from 
lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) +from lightllm.utils.envs_utils import get_env_start_args @ModelRegistry("glm4_moe_lite") @@ -20,6 +18,29 @@ class Glm4MoeLiteTpPartModel(Deepseek2TpPartModel): def __init__(self, kvargs): super().__init__(kvargs) + def _init_att_backend(self): + args = get_env_start_args() + prefill_backend_str = args.llm_prefill_att_backend[0] + decode_backend_str = args.llm_decode_att_backend[0] + + if prefill_backend_str in ("triton", "auto"): + from lightllm.models.glm4_moe_lite.triton_kernel.mla_att import GlmMlaTritonAttBackend + + self.prefill_att_backend = GlmMlaTritonAttBackend(model=self) + else: + from lightllm.common.basemodel.attention import get_mla_prefill_att_backend_class + + self.prefill_att_backend = get_mla_prefill_att_backend_class(index=0)(model=self) + + if decode_backend_str in ("triton", "auto"): + from lightllm.models.glm4_moe_lite.triton_kernel.mla_att import GlmMlaTritonAttBackend + + self.decode_att_backend = GlmMlaTritonAttBackend(model=self) + else: + from lightllm.common.basemodel.attention import get_mla_decode_att_backend_class + + self.decode_att_backend = get_mla_decode_att_backend_class(index=0)(model=self) + def _init_config(self): super()._init_config() if "scoring_func" not in self.config: diff --git a/lightllm/models/glm4_moe_lite/triton_kernel/__init__.py b/lightllm/models/glm4_moe_lite/triton_kernel/__init__.py new file mode 100644 index 000000000..ddf45253a --- /dev/null +++ b/lightllm/models/glm4_moe_lite/triton_kernel/__init__.py @@ -0,0 +1 @@ +# GLM4 MoE Lite specific triton kernels diff --git a/lightllm/models/glm4_moe_lite/triton_kernel/context_flashattention_nopad.py b/lightllm/models/glm4_moe_lite/triton_kernel/context_flashattention_nopad.py new file mode 100644 index 000000000..a23a50b8d --- /dev/null +++ b/lightllm/models/glm4_moe_lite/triton_kernel/context_flashattention_nopad.py @@ -0,0 +1,247 @@ +import torch +import triton +import triton.language as tl +import itertools +from lightllm.utils.device_utils import is_tesla +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _glm_fwd_kernel_with_v( + Q_nope, + Q_rope, + K_nope, + K_rope, + V, + sm_scale, + B_Start_Loc, + B_Kv_Start_Loc, + B_Seqlen, + Out, + stride_q_bs, + stride_q_h, + stride_q_d, + stride_q_rope_bs, + stride_q_rope_h, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_k_rope_bs, + stride_k_rope_h, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + b_prompt_cache_len, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_NOPE_DIM: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_k_head = cur_head + + cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + + block_start_loc = BLOCK_M * start_m + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + nope_valid_mask = offs_d < ACTUAL_NOPE_DIM + + off_q = ( + (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + + cur_head * stride_q_h + + offs_d[None, :] * stride_q_d + ) + off_q_rope = ( + (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs + + 
cur_head * stride_q_rope_h + + offs_rope_d[None, :] + ) + + off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d + off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] + + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] + + seq_mask_q = offs_m[:, None] < cur_batch_seq_len + q = tl.load(Q_nope + off_q, mask=seq_mask_q & nope_valid_mask[None, :], other=0.0) + q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K_nope + off_k + k_rope_ptrs = K_rope + off_k_rope + v_ptrs = V + off_v + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + seq_mask_k = (start_n + offs_n[None, :]) < block_end_loc + + k = tl.load( + k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs, + mask=seq_mask_k & nope_valid_mask[:, None], + other=0.0, + ) + k_rope = tl.load( + k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs, + mask=(start_n + offs_n[None, :]) < block_end_loc, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk += tl.dot(q_rope, k_rope) + qk *= sm_scale + qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + v = tl.load( + v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < block_end_loc, + other=0.0, + ) + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + m_i = m_ij + + acc = acc / l_i[:, None] + off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + + +def get_autotune_configs(): + configs = [] + block_sizes = [32, 64, 128] if not is_tesla() else [16, 32, 64] + num_warps_options = [4, 8] + num_stages_options = [1, 2] + + for block_size, num_warps, num_stages in itertools.product(block_sizes, num_warps_options, num_stages_options): + configs.append( + { + "BLOCK": block_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_static_key(q_nope, q_rope, v): + return { + "q_nope_dim": q_nope.shape[-1], + "q_rope_dim": q_rope.shape[-1], + "v_dim": v.shape[-1], + "num_heads": q_nope.shape[1], + "dtype": str(q_nope.dtype), + } + + +def get_run_key(max_input_len): + return max_input_len + + +@autotune( + kernel_name="glm_context_attention_fwd_with_v:v1", + configs_gen_func=get_autotune_configs, + static_key_func=get_static_key, + run_key_func=get_run_key, + mutates_args=["o"], +) +@torch.no_grad() +def glm_context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o, + b_start_loc, + b_kv_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + softmax_scale, + run_config=None, +): + ACTUAL_NOPE_DIM = q_nope.shape[-1] + BLOCK_DMODEL = v.shape[-1] + BLOCK_ROPE_DMODEL = 
q_rope.shape[-1] + + if run_config is None: + BLOCK = 64 if not is_tesla() else 32 + num_warps = 4 + num_stages = 1 + else: + BLOCK = run_config["BLOCK"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + if q_nope.dtype == torch.float32: + BLOCK = BLOCK // 4 + + sm_scale = softmax_scale * 1.4426950408889634 + batch, head = b_seq_len.shape[0], q_nope.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + _glm_fwd_kernel_with_v[grid]( + q_nope, + q_rope, + k_nope, + k_rope, + v, + sm_scale, + b_start_loc, + b_kv_start_loc, + b_seq_len, + o, + q_nope.stride(0), + q_nope.stride(1), + q_nope.stride(2), + q_rope.stride(0), + q_rope.stride(1), + k_nope.stride(0), + k_nope.stride(1), + k_nope.stride(2), + k_rope.stride(0), + k_rope.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + b_prompt_cache_len=b_prompt_cache_len, + BLOCK_M=BLOCK, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_NOPE_DIM=ACTUAL_NOPE_DIM, + BLOCK_ROPE_DMODEL=BLOCK_ROPE_DMODEL, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/lightllm/models/glm4_moe_lite/triton_kernel/mla_att.py b/lightllm/models/glm4_moe_lite/triton_kernel/mla_att.py new file mode 100644 index 000000000..c46173e1a --- /dev/null +++ b/lightllm/models/glm4_moe_lite/triton_kernel/mla_att.py @@ -0,0 +1,134 @@ +"""GLM-4.7-Flash MLA attention backend.""" + +import dataclasses +import torch +from lightllm.common.basemodel.attention.base_att import ( + BaseAttBackend, + BasePrefillAttState, + BaseDecodeAttState, + AttControl, +) +from typing import Tuple + + +class GlmMlaTritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "GlmMlaTritonPrefillAttState": + return GlmMlaTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "GlmMlaTritonDecodeAttState": + return GlmMlaTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class GlmMlaTritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + from .context_flashattention_nopad import glm_context_attention_fwd_with_v + + qk_rope_head_dim = 64 + q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] + o_tensor = alloc_func((q_nope.shape[0], q_nope.shape[1], v.shape[-1]), dtype=q_nope.dtype, device=q.device) + k_nope, k_rope = k + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + glm_context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o_tensor, + self.infer_state.b_q_start_loc, + self.infer_state.b1_cu_kv_seq_len, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + softmax_scale, + ) + return o_tensor + + +@dataclasses.dataclass +class GlmMlaTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: 
"GlmMlaTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_sliding_window is False + and att_control.use_att_sink is False + and att_control.use_alibi is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + from lightllm.common.basemodel.triton_kernel.mla_att.decode_att import ( + gqa_token_decode_attention_flash_decoding, + ) + + qk_rope_head_dim = 64 + q_nope, q_rope = q + kv = k + + out = gqa_token_decode_attention_flash_decoding( + q_nope=q_nope, + q_rope=q_rope, + kv_nope=kv[:, :, :-qk_rope_head_dim], + kv_rope=kv[:, :, -qk_rope_head_dim:], + infer_state=self.infer_state, + softmax_scale=softmax_scale, + alloc_tensor_func=alloc_func, + ) + return out