From 3ce980475f1d68327981f2cc2b5ab21420108a25 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 22 Jan 2026 10:49:14 +0000 Subject: [PATCH 01/27] add-choose-vit-backend --- .../common/basemodel/attention/__init__.py | 1 + .../common/basemodel/attention/base_att.py | 20 ++++++ .../basemodel/attention/create_utils.py | 33 ++++++++- lightllm/common/basemodel/attention/fa3/fp.py | 69 ++++++++++++++++++- .../common/basemodel/attention/triton/fp.py | 33 ++++++++- lightllm/common/basemodel/basemodel.py | 3 +- lightllm/models/qwen2_vl/qwen2_visual.py | 8 ++- lightllm/models/qwen3_vl/qwen3_visual.py | 6 +- .../visualserver/model_infer/model_rpc.py | 10 ++- 9 files changed, 175 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df54549..d6fba6966 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -15,4 +15,5 @@ get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_vit_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca8..28fd5596e 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -37,6 +37,9 @@ def create_att_prefill_state(self) -> "BasePrefillAttState": def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") + def create_vit_att_state(self) -> "BaseVitAttState": + raise NotImplementedError("not impl") + def _find_layer_index( self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] ) -> int: @@ -115,3 +118,20 @@ def decode_att( alloc_func=torch.empty, ) -> torch.Tensor: pass + + +class BaseVitAttState(ABC): + + backend: BaseAttBackend = None + + @abstractmethod + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + raise NotImplementedError("not impl") diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13..40844c0e0 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -10,7 +10,7 @@ from .triton.int4kv import Int4kvTritonAttBackend from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend -from .fa3.fp import Fa3AttBackend +from .fa3.fp import Fa3AttBackend, Fa3ViTAttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend @@ -46,6 +46,13 @@ }, } +vit_data_type_to_backend = { + "None": { + "triton": TritonAttBackend, + "fa3": Fa3ViTAttBackend, + }, +} + def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -60,6 +67,7 @@ def _auto_select_backend( for backend_name in priority_list: if validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated)") + print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") return backend_map[llm_dtype][backend_name] # Fallback to triton without validation (should not happen) @@ -67,6 +75,25 @@ def _auto_select_backend( return backend_map[llm_dtype]["triton"] +def _auto_select_vit_backend(llm_dtype: str, priority_list: list = 
["fa3", "triton"]) -> type: + """Auto-select the best available backend with validation for vit. + + Priority: FA3 > Triton + Each backend is validated in a subprocess with ground truth checks. + """ + backend_map = vit_data_type_to_backend + + for backend_name in priority_list: + if validate(backend_name): + logger.info(f"Auto-selected {backend_name} backend (validated) for ViT") + print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") + return backend_map[llm_dtype][backend_name] + + # Fallback to triton without validation (should not happen) + logger.warning("No backend validation succeeded for vit, falling back to triton") + return backend_map[llm_dtype]["triton"] + + def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type @@ -105,3 +132,7 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + + +def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "triton"]) -> BaseAttBackend: + return _auto_select_vit_backend(llm_dtype="None", priority_list=priority_list) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d9..7369f7dee 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache @@ -37,6 +37,14 @@ def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": return Fa3DecodeAttState(backend=self, infer_state=infer_state) +class Fa3ViTAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_vit_att_state(self) -> "Fa3VitAttState": + return Fa3VitAttState(backend=self) + + @dataclasses.dataclass class Fa3PrefillAttState(BasePrefillAttState): cu_seqlens_q: torch.Tensor = None @@ -241,3 +249,62 @@ def _normal_decode_att( sinks=sink_weight, ) return o + + +@dataclasses.dataclass +class Fa3VitAttState(BaseVitAttState): + + backend: "Fa3ViTAttBackend" + + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + self.backend: Fa3ViTAttBackend = self.backend # for typing + + head_dim = q.shape[-1] + softmax_scale = head_dim ** -0.5 + window_size = (-1, -1) + torch.ops.sgl_kernel.fwd.default( + q, + k, + v, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention/triton/fp.py 
b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3..c35ad362a 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState from typing import Optional @@ -11,6 +11,9 @@ def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": return TritonDecodeAttState(backend=self, infer_state=infer_state) + def create_vit_att_state(self, infer_state) -> "TritonDecodeAttState": + return TritonVitAttState(backend=self, infer_state=infer_state) + @dataclasses.dataclass class TritonPrefillAttState(BasePrefillAttState): @@ -273,3 +276,31 @@ def _normal_decode_stage3_att( b_seq_len=self.infer_state.b_seq_len, ) return o_tensor + + +@dataclasses.dataclass +class TritonVitAttState(BaseVitAttState): + def init_state(self): + pass + + def _vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + alloc_func=torch.empty, + ): + from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd + + _flash_attention_triton_fwd( + q, + k, + v, + o, + cu_seqlens, # q k v cu_seqlens, + max_seqlen, + ) + return diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 26d51af3d..e97381884 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache -from .attention import get_prefill_att_backend_class, get_decode_att_backend_class +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class, get_vit_att_backend_class from .attention import BaseAttBackend logger = init_logger(__name__) @@ -119,7 +119,6 @@ def __init__(self, kvargs): self._init_custom() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() - self._init_att_backend() self._init_att_backend1() diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 334ffc844..7d53c326f 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -127,6 +127,7 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + self.vit_att_backend = None def forward( self, @@ -143,7 +144,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + self.vit_att_backend.vit_att(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output @@ -234,6 +235,11 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + def _init_vit_att(self, vit_att): + for blk in self.blocks: + blk.attn.vit_att_backend = vit_att + return + def load_model(self, weight_dir): processor_config_path = os.path.join(weight_dir, 
"preprocessor_config.json") diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index 00ad6c05a..dd119f2de 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -129,7 +129,6 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") - self.depth = depth self.out_hidden_size = out_hidden_size self.hidden_size = hidden_size @@ -182,6 +181,11 @@ def __init__( ) self._init_datatype() + def _init_vit_att(self, vit_att): + for blk in self.blocks: + blk.attn.vit_att_backend = vit_att + return + def _init_datatype(self): if isinstance(self.data_type, torch.dtype): return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..79434f5b9 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,9 +24,15 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel.basemodel import BaseAttBackend, get_vit_att_backend_class +from lightllm.utils.dist_utils import set_global_rank class VisualModelRpcServer(rpyc.Service): + def _init_vit_att_backend(self): + self.vit_att_backend: BaseAttBackend = get_vit_att_backend_class(index=0)(model=self) + return + def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) import torch @@ -42,7 +48,7 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] - + set_global_rank(kvargs["tp_rank_id"]) # 这里看后面怎么改 init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -81,6 +87,8 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") + self._init_vit_att_backend() + self.model._init_vit_att(self.vit_att_backend.create_vit_att_state()) self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) From 7ec0e7af63713a1527175119fd3cdc1827e458d0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 02:49:01 +0000 Subject: [PATCH 02/27] add vit_attention dirs. 
--- lightllm/common/basemodel/attention_vit/__init__.py | 0 lightllm/common/basemodel/attention_vit/fa3/__init__.py | 0 lightllm/common/basemodel/attention_vit/triton/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 lightllm/common/basemodel/attention_vit/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/fa3/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/triton/__init__.py diff --git a/lightllm/common/basemodel/attention_vit/__init__.py b/lightllm/common/basemodel/attention_vit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/fa3/__init__.py b/lightllm/common/basemodel/attention_vit/fa3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/triton/__init__.py b/lightllm/common/basemodel/attention_vit/triton/__init__.py new file mode 100644 index 000000000..e69de29bb From e76b3cc9e5ad7e04b053ce35a2276aba58810feb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:01:12 +0000 Subject: [PATCH 03/27] fix. --- .../basemodel/attention_vit/base_att.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 lightllm/common/basemodel/attention_vit/base_att.py diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py new file mode 100644 index 000000000..1b51ae08a --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -0,0 +1,46 @@ +import torch +from abc import ABC, abstractmethod + + +class BaseVitAttBackend: + """ + 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 这个是单列模式, 每种backend只有一个实例 + """ + + _instances = {} + + def __new__(cls, *args, **kwargs): + """ + 重写__new__方法实现单例模式 + """ + # 检查是否已经有该类的实例 + if cls not in cls._instances: + # 创建新实例并存储 + instance = super().__new__(cls) + cls._instances[cls] = instance + # 返回已有的实例 + return cls._instances[cls] + + def __init__(self, model): + self.model = model + + def create_vit_att_state(self) -> "BaseVitAttState": + raise NotImplementedError("not impl") + + +class BaseVitAttState(ABC): + + backend: BaseVitAttBackend = None + + @abstractmethod + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + raise NotImplementedError("not impl") From f8537bb54e95eebc9b90a450632d6d9adc53b051 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:08:26 +0000 Subject: [PATCH 04/27] fix --- .../common/basemodel/attention_vit/fa3/fp.py | 68 +++++++++++++++++++ .../basemodel/attention_vit/triton/fp.py | 36 ++++++++++ 2 files changed, 104 insertions(+) create mode 100644 lightllm/common/basemodel/attention_vit/fa3/fp.py create mode 100644 lightllm/common/basemodel/attention_vit/triton/fp.py diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py new file mode 100644 index 000000000..dfd437581 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -0,0 +1,68 @@ +import dataclasses +import torch +from ..base_att import BaseVitAttState, BaseVitAttBackend +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + + +class Fa3VitAttBackend(BaseVitAttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_vit_att_state(self) -> "Fa3VitAttState": + return Fa3VitAttState(backend=self) + + +@dataclasses.dataclass +class 
Fa3VitAttState(BaseVitAttState): + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + self.backend: Fa3VitAttBackend = self.backend # for typing + + head_dim = q.shape[-1] + softmax_scale = head_dim ** -0.5 + window_size = (-1, -1) + torch.ops.sgl_kernel.fwd.default( + q, + k, + v, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py new file mode 100644 index 000000000..506313a8d --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -0,0 +1,36 @@ +import dataclasses +import torch +from ..base_att import BaseVitAttBackend, BaseVitAttState + + +class TritonVitAttBackend(BaseVitAttBackend): + def create_vit_att_state(self, infer_state) -> "TritonVitAttState": + return TritonVitAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class TritonVitAttState(BaseVitAttState): + def init_state(self): + pass + + def _vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + alloc_func=torch.empty, + ): + from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd + + _flash_attention_triton_fwd( + q, + k, + v, + o, + cu_seqlens, # q k v cu_seqlens, + max_seqlen, + ) + return From 65549a746da24a0df31c78cef3cf3b380f5b76b5 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:10:33 +0000 Subject: [PATCH 05/27] fix --- .../common/basemodel/attention/__init__.py | 1 - .../common/basemodel/attention/base_att.py | 20 ------ .../basemodel/attention/create_utils.py | 33 +-------- lightllm/common/basemodel/attention/fa3/fp.py | 69 +------------------ .../common/basemodel/attention/triton/fp.py | 33 +-------- 5 files changed, 3 insertions(+), 153 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index d6fba6966..80df54549 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -15,5 +15,4 @@ get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, - get_vit_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 28fd5596e..859d97ca8 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -37,9 +37,6 @@ def create_att_prefill_state(self) -> "BasePrefillAttState": def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") - def create_vit_att_state(self) -> "BaseVitAttState": - raise NotImplementedError("not impl") - def _find_layer_index( self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] ) -> int: @@ -118,20 +115,3 @@ def decode_att( 
alloc_func=torch.empty, ) -> torch.Tensor: pass - - -class BaseVitAttState(ABC): - - backend: BaseAttBackend = None - - @abstractmethod - def vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - ) -> torch.Tensor: - raise NotImplementedError("not impl") diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 40844c0e0..19252cf13 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -10,7 +10,7 @@ from .triton.int4kv import Int4kvTritonAttBackend from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend -from .fa3.fp import Fa3AttBackend, Fa3ViTAttBackend +from .fa3.fp import Fa3AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend @@ -46,13 +46,6 @@ }, } -vit_data_type_to_backend = { - "None": { - "triton": TritonAttBackend, - "fa3": Fa3ViTAttBackend, - }, -} - def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -67,7 +60,6 @@ def _auto_select_backend( for backend_name in priority_list: if validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated)") - print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") return backend_map[llm_dtype][backend_name] # Fallback to triton without validation (should not happen) @@ -75,25 +67,6 @@ def _auto_select_backend( return backend_map[llm_dtype]["triton"] -def _auto_select_vit_backend(llm_dtype: str, priority_list: list = ["fa3", "triton"]) -> type: - """Auto-select the best available backend with validation for vit. - - Priority: FA3 > Triton - Each backend is validated in a subprocess with ground truth checks. 
- """ - backend_map = vit_data_type_to_backend - - for backend_name in priority_list: - if validate(backend_name): - logger.info(f"Auto-selected {backend_name} backend (validated) for ViT") - print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") - return backend_map[llm_dtype][backend_name] - - # Fallback to triton without validation (should not happen) - logger.warning("No backend validation succeeded for vit, falling back to triton") - return backend_map[llm_dtype]["triton"] - - def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type @@ -132,7 +105,3 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) - - -def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "triton"]) -> BaseAttBackend: - return _auto_select_vit_backend(llm_dtype="None", priority_list=priority_list) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 7369f7dee..952bb39d9 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache @@ -37,14 +37,6 @@ def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": return Fa3DecodeAttState(backend=self, infer_state=infer_state) -class Fa3ViTAttBackend(BaseAttBackend): - def __init__(self, model): - super().__init__(model=model) - - def create_vit_att_state(self) -> "Fa3VitAttState": - return Fa3VitAttState(backend=self) - - @dataclasses.dataclass class Fa3PrefillAttState(BasePrefillAttState): cu_seqlens_q: torch.Tensor = None @@ -249,62 +241,3 @@ def _normal_decode_att( sinks=sink_weight, ) return o - - -@dataclasses.dataclass -class Fa3VitAttState(BaseVitAttState): - - backend: "Fa3ViTAttBackend" - - def vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - ) -> None: - self.backend: Fa3ViTAttBackend = self.backend # for typing - - head_dim = q.shape[-1] - softmax_scale = head_dim ** -0.5 - window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( - q, - k, - v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], - 0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, - ) - - return o diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index c35ad362a..d29f15ec3 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -1,6 +1,6 @@ import 
dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional @@ -11,9 +11,6 @@ def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": return TritonDecodeAttState(backend=self, infer_state=infer_state) - def create_vit_att_state(self, infer_state) -> "TritonDecodeAttState": - return TritonVitAttState(backend=self, infer_state=infer_state) - @dataclasses.dataclass class TritonPrefillAttState(BasePrefillAttState): @@ -276,31 +273,3 @@ def _normal_decode_stage3_att( b_seq_len=self.infer_state.b_seq_len, ) return o_tensor - - -@dataclasses.dataclass -class TritonVitAttState(BaseVitAttState): - def init_state(self): - pass - - def _vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - alloc_func=torch.empty, - ): - from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd - - _flash_attention_triton_fwd( - q, - k, - v, - o, - cu_seqlens, # q k v cu_seqlens, - max_seqlen, - ) - return From e35427c6436d61ea1a8f22494846b366bdaf8b36 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:26:39 +0000 Subject: [PATCH 06/27] fix --- lightllm/common/basemodel/attention_vit/base_att.py | 10 +--------- lightllm/common/basemodel/attention_vit/fa3/fp.py | 9 +-------- lightllm/common/basemodel/attention_vit/triton/fp.py | 12 +----------- 3 files changed, 3 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index 1b51ae08a..e43475ac0 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -class BaseVitAttBackend: +class BaseVitAttBackend(ABC): """ 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, 这个是单列模式, 每种backend只有一个实例 @@ -25,14 +25,6 @@ def __new__(cls, *args, **kwargs): def __init__(self, model): self.model = model - def create_vit_att_state(self) -> "BaseVitAttState": - raise NotImplementedError("not impl") - - -class BaseVitAttState(ABC): - - backend: BaseVitAttBackend = None - @abstractmethod def vit_att( self, diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index dfd437581..fa7f14adb 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,19 +1,12 @@ import dataclasses import torch -from ..base_att import BaseVitAttState, BaseVitAttBackend -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from ..base_att import BaseVitAttBackend class Fa3VitAttBackend(BaseVitAttBackend): def __init__(self, model): super().__init__(model=model) - def create_vit_att_state(self) -> "Fa3VitAttState": - return Fa3VitAttState(backend=self) - - -@dataclasses.dataclass -class Fa3VitAttState(BaseVitAttState): def vit_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 506313a8d..2c012f5a2 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -1,18 +1,8 @@ -import 
dataclasses import torch -from ..base_att import BaseVitAttBackend, BaseVitAttState +from ..base_att import BaseVitAttBackend class TritonVitAttBackend(BaseVitAttBackend): - def create_vit_att_state(self, infer_state) -> "TritonVitAttState": - return TritonVitAttState(backend=self, infer_state=infer_state) - - -@dataclasses.dataclass -class TritonVitAttState(BaseVitAttState): - def init_state(self): - pass - def _vit_att( self, q: torch.Tensor, From c18dd256f752e5d359191b142a9fd030848047c6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 04:38:01 +0000 Subject: [PATCH 07/27] fix --- lightllm/common/basemodel/attention_vit/fa3/fp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index fa7f14adb..ed1be1400 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -16,7 +16,6 @@ def vit_att( cu_seqlens: torch.Tensor, max_seqlen: int, ) -> None: - self.backend: Fa3VitAttBackend = self.backend # for typing head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 From ad46e4aa63ac5fb78f4818a7d78aa67f3bd42825 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 06:02:03 +0000 Subject: [PATCH 08/27] fix --- lightllm/common/basemodel/basemodel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e97381884..26d51af3d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache -from .attention import get_prefill_att_backend_class, get_decode_att_backend_class, get_vit_att_backend_class +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend logger = init_logger(__name__) @@ -119,6 +119,7 @@ def __init__(self, kvargs): self._init_custom() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + self._init_att_backend() self._init_att_backend1() From ee1d9aa3e6d52f322042c908bbd1945df95d3866 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:32:23 +0000 Subject: [PATCH 09/27] fix0126 --- .../basemodel/attention_vit/base_att.py | 4 +- .../basemodel/attention_vit/create_utils.py | 101 ++++++++++++++++++ .../common/basemodel/attention_vit/fa3/fp.py | 5 +- .../basemodel/attention_vit/sdpa/__init__.py | 0 .../common/basemodel/attention_vit/sdpa/fp.py | 48 +++++++++ .../basemodel/attention_vit/triton/fp.py | 10 +- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 4 +- lightllm/models/qwen2_vl/qwen2_visual.py | 5 +- lightllm/models/qwen3_vl/qwen3_visual.py | 5 - .../layer_infer/transformer_layer_infer.py | 4 +- lightllm/server/api_cli.py | 10 ++ lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/visualserver/__init__.py | 16 +++ .../visualserver/model_infer/model_rpc.py | 10 +- .../benchmark/static_inference/model_infer.py | 1 + 15 files changed, 192 insertions(+), 32 deletions(-) create mode 100644 lightllm/common/basemodel/attention_vit/create_utils.py create mode 100644 lightllm/common/basemodel/attention_vit/sdpa/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/sdpa/fp.py diff --git 
a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index e43475ac0..405aeb245 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -22,8 +22,8 @@ def __new__(cls, *args, **kwargs): # 返回已有的实例 return cls._instances[cls] - def __init__(self, model): - self.model = model + def __init__(self): + pass @abstractmethod def vit_att( diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py new file mode 100644 index 000000000..8983ea087 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -0,0 +1,101 @@ +import torch +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend +from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend +from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend + +logger = init_logger(__name__) + + +vit_att_backend = {"triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend} + + +def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "sdpa", "triton"]) -> BaseVitAttBackend: + args = get_env_start_args() + backend_str = args.vit_att_backend[index] + if backend_str != "auto": + return vit_att_backend[backend_str] + else: + return _select_vit_backend(priority_list=priority_list) + + +def _select_vit_backend(priority_list: list = ["fa3", "sdpa", "triton"]) -> type: + """Auto-select the best available backend with validation for VIT. + + Priority: FA3 > Sdpa > Triton + Each backend is validated in a subprocess with ground truth checks. 
+ """ + backend_map = vit_att_backend + + for backend_name in priority_list: + if validate(backend_name): + logger.info(f"Auto-selected {backend_name} backend (validated) for VIT") + return backend_map[backend_name] + + # Fallback to triton without validation (should not happen) + logger.warning("No backend validation succeeded, falling back to triton") + return backend_map["triton"] + + +def validate(backend_name: str) -> bool: + if backend_name == "fa3": + validate_ok = _validate_fa3() + elif backend_name == "sdpa": + validate_ok = _validate_sdpa() + elif backend_name == "triton": + validate_ok = _validate_triton() + else: + raise ValueError("not suuported vit attn backend") + return validate_ok + + +def _validate_fa3(): + """Validate FA3 with ground truth.""" + from lightllm.utils.sgl_utils import flash_attn_varlen_func + + if flash_attn_varlen_func is None: + return False + + batch, heads, seq, dim = 1, 4, 8, 64 + q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + + expected = _compute_ground_truth(q, k, v) + + q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim) + k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim) + v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim) + cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda") + + out = flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq, + max_seqlen_k=seq, + softmax_scale=1.0 / (dim ** 0.5), + causal=True, + ) + out = out.reshape(batch, seq, heads, dim).transpose(1, 2) + torch.cuda.synchronize() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False + + return True + + +def _validate_sdpa(): + """Validate SDPA Attn""" + from torch.nn.functional import scaled_dot_product_attention + + if scaled_dot_product_attention is None: + return False + + return True diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index ed1be1400..cec91165d 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,12 +1,9 @@ import dataclasses import torch -from ..base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend class Fa3VitAttBackend(BaseVitAttBackend): - def __init__(self, model): - super().__init__(model=model) - def vit_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/sdpa/__init__.py b/lightllm/common/basemodel/attention_vit/sdpa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py new file mode 100644 index 000000000..7f57d98a2 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class SdpaVitAttBackend(BaseVitAttBackend): + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + + cu = 
cu_seqlens.to(device=q.device) + B = cu.numel() - 1 + + with torch.no_grad(): + for b in range(B): + s = int(cu[b].item()) + e = int(cu[b + 1].item()) + L = e - s + if L <= 0: + continue + if max_seqlen: + assert L <= max_seqlen + + # [L, H, D] -> [1, H, L, D] + q_ = q[s:e].permute(1, 0, 2).unsqueeze(0) + k_ = k[s:e].permute(1, 0, 2).unsqueeze(0) + v_ = v[s:e].permute(1, 0, 2).unsqueeze(0) + + out = F.scaled_dot_product_attention( + q_, + k_, + v_, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + # [1, H, L, D] -> [L, H, D] + o[s:e].copy_(out.squeeze(0).permute(1, 0, 2)) + + return o diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 2c012f5a2..51e47f056 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -1,9 +1,10 @@ import torch -from ..base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd class TritonVitAttBackend(BaseVitAttBackend): - def _vit_att( + def vit_att( self, q: torch.Tensor, k: torch.Tensor, @@ -11,10 +12,7 @@ def _vit_att( o: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, - alloc_func=torch.empty, ): - from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd - _flash_attention_triton_fwd( q, k, @@ -23,4 +21,4 @@ def _vit_att( cu_seqlens, # q k v cu_seqlens, max_seqlen, ) - return + return o diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 498f82e14..7156a5ce2 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -13,7 +13,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -74,7 +74,7 @@ def forward( k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 7d53c326f..8cf23bd57 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -32,8 +32,8 @@ from transformers.activations import ACT2FN from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from 
lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -127,7 +127,6 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) - self.vit_att_backend = None def forward( self, @@ -144,7 +143,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - self.vit_att_backend.vit_att(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index dd119f2de..e3e350729 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -181,11 +181,6 @@ def __init__( ) self._init_datatype() - def _init_vit_att(self, vit_att): - for blk in self.blocks: - blk.attn.vit_att_backend = vit_att - return - def _init_datatype(self): if isinstance(self.data_type, torch.dtype): return diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0d55d1b57..cada55e58 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -3,10 +3,10 @@ from typing import Tuple from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -105,7 +105,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor: q, k, v, out = map(reshape, (q, k, v, out)) cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len max_seqlen = seq_len - flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, out, cu_seqlens, max_seqlen) return out.reshape(batch_size, seq_len, -1) def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..199f0dccc 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -333,6 +333,16 @@ def make_argument_parser() -> argparse.ArgumentParser: auto: automatically select best backend based on GPU and available packages (priority: fa3 > flashinfer > triton)""", ) + parser.add_argument( + "--vit_att_backend", + type=str, + nargs="+", + choices=["auto", "triton", "fa3", "sdpa"], + default=["auto"], + help="""vit attention kernel used in vlm. 
+ auto: automatically select best backend based on GPU and available packages + (priority: fa3 > sdpa > triton)""", + ) parser.add_argument( "--llm_kv_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 239cebfdd..0608d3972 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -121,6 +121,7 @@ class StartArgs: llm_decode_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) + vit_att_backend: List[str] = field(default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa"]}) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index e69de29bb..c8c108938 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -0,0 +1,16 @@ +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.create_utils import get_vit_att_backend_class + +VIT_ATTN_BACKEND: BaseVitAttBackend = None + + +def init_vit_att_backend(): + global VIT_ATTN_BACKEND + VIT_ATTN_BACKEND = get_vit_att_backend_class(index=0)() + return + + +def get_vit_attn_backend(): + if VIT_ATTN_BACKEND is None: + raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call init_vit_att_backend() first.") + return VIT_ATTN_BACKEND.vit_att diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 79434f5b9..674c7a83d 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,15 +24,11 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.common.basemodel.basemodel import BaseAttBackend, get_vit_att_backend_class +from lightllm.server.visualserver import init_vit_att_backend from lightllm.utils.dist_utils import set_global_rank class VisualModelRpcServer(rpyc.Service): - def _init_vit_att_backend(self): - self.vit_att_backend: BaseAttBackend = get_vit_att_backend_class(index=0)(model=self) - return - def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) import torch @@ -48,7 +44,6 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] - set_global_rank(kvargs["tp_rank_id"]) # 这里看后面怎么改 init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -87,8 +82,7 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") - self._init_vit_att_backend() - self.model._init_vit_att(self.vit_att_backend.create_vit_att_state()) + init_vit_att_backend() self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) diff --git 
a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 7f1c2b493..919f379b9 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -43,6 +43,7 @@ def test_model_inference(args): "disable_cudagraph": args.disable_cudagraph, "llm_prefill_att_backend": args.llm_prefill_att_backend, "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, "llm_kv_type": args.llm_kv_type, "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } From 31cf3b0299c24a06012e29502733c649ee60f19a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:34:51 +0000 Subject: [PATCH 10/27] fix0126 --- lightllm/models/qwen2_vl/qwen2_visual.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 8cf23bd57..0e2af0cbb 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -234,11 +234,6 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return - def _init_vit_att(self, vit_att): - for blk in self.blocks: - blk.attn.vit_att_backend = vit_att - return - def load_model(self, weight_dir): processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") From f0bfc1cd5e9b17fa10d10b14270b86df9c34c1f4 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:35:43 +0000 Subject: [PATCH 11/27] fix0126 --- lightllm/models/qwen3_vl/qwen3_visual.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index e3e350729..00ad6c05a 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -129,6 +129,7 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") + self.depth = depth self.out_hidden_size = out_hidden_size self.hidden_size = hidden_size From 7f0fd35b292a0c6c4d0dcc26b07fd018a5c364a9 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:37:54 +0000 Subject: [PATCH 12/27] fix0126 --- lightllm/common/basemodel/attention_vit/base_att.py | 4 ++-- lightllm/common/basemodel/attention_vit/fa3/fp.py | 2 +- lightllm/common/basemodel/attention_vit/sdpa/fp.py | 2 +- lightllm/common/basemodel/attention_vit/triton/fp.py | 2 +- lightllm/server/visualserver/__init__.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index 405aeb245..49bf6ad74 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -4,7 +4,7 @@ class BaseVitAttBackend(ABC): """ - 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 用于创建支持各种不同的AttBackend, 如 fa3, sdpa, triton 实现等, 这个是单列模式, 每种backend只有一个实例 """ @@ -26,7 +26,7 @@ def __init__(self): pass @abstractmethod - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index cec91165d..e77a0cec7 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -4,7 +4,7 @@ class Fa3VitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git 
a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py index 7f57d98a2..9c7d5e311 100644 --- a/lightllm/common/basemodel/attention_vit/sdpa/fp.py +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -4,7 +4,7 @@ class SdpaVitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 51e47f056..c38a46633 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -4,7 +4,7 @@ class TritonVitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index c8c108938..026458447 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -13,4 +13,4 @@ def init_vit_att_backend(): def get_vit_attn_backend(): if VIT_ATTN_BACKEND is None: raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call init_vit_att_backend() first.") - return VIT_ATTN_BACKEND.vit_att + return VIT_ATTN_BACKEND._vit_att_fwd From a5aa81782119b0da90390c5dd1705352fa5ac92a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:46:52 +0000 Subject: [PATCH 13/27] fix0126 --- lightllm/server/visualserver/model_infer/model_rpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 674c7a83d..a81e89efe 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -45,6 +45,7 @@ def exposed_init_model(self, kvargs): self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] init_vision_distributed_env(kvargs) + init_vit_att_backend() model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: @@ -82,7 +83,6 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") - init_vit_att_backend() self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) From 04ff9ab318080a81002822886c2b05d94ed54868 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 27 Jan 2026 06:40:23 +0000 Subject: [PATCH 14/27] add-xformers --- .../basemodel/attention_vit/create_utils.py | 27 ++++++++++-- .../attention_vit/xformers/__init__.py | 0 .../basemodel/attention_vit/xformers/fp.py | 42 +++++++++++++++++++ lightllm/server/api_cli.py | 4 +- lightllm/server/core/objs/start_args_type.py | 4 +- 5 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 lightllm/common/basemodel/attention_vit/xformers/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/xformers/fp.py diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 8983ea087..fee5a7bf1 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -6,23 +6,32 @@ from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend 
from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend +from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend logger = init_logger(__name__) -vit_att_backend = {"triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend} +vit_att_backend = { + "triton": TritonVitAttBackend, + "sdpa": SdpaVitAttBackend, + "fa3": Fa3VitAttBackend, + "xformers": XformersVitAttBackend, +} -def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "sdpa", "triton"]) -> BaseVitAttBackend: +def get_vit_att_backend_class( + index=0, priority_list: list = ["fa3", "xformers", "sdpa", "triton"] +) -> BaseVitAttBackend: args = get_env_start_args() backend_str = args.vit_att_backend[index] if backend_str != "auto": + logger.info(f"Selected {backend_str} backend for VIT") return vit_att_backend[backend_str] else: return _select_vit_backend(priority_list=priority_list) -def _select_vit_backend(priority_list: list = ["fa3", "sdpa", "triton"]) -> type: +def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> type: """Auto-select the best available backend with validation for VIT. Priority: FA3 > Sdpa > Triton @@ -43,6 +52,8 @@ def _select_vit_backend(priority_list: list = ["fa3", "sdpa", "triton"]) -> type def validate(backend_name: str) -> bool: if backend_name == "fa3": validate_ok = _validate_fa3() + elif backend_name == "xformers": + validate_ok = _validate_xformers() elif backend_name == "sdpa": validate_ok = _validate_sdpa() elif backend_name == "triton": @@ -91,6 +102,16 @@ def _validate_fa3(): return True +def _validate_xformers(): + """Validate Xformers Attn""" + from xformers import ops as xformers_ops + + if xformers_ops is None: + return False + + return True + + def _validate_sdpa(): """Validate SDPA Attn""" from torch.nn.functional import scaled_dot_product_attention diff --git a/lightllm/common/basemodel/attention_vit/xformers/__init__.py b/lightllm/common/basemodel/attention_vit/xformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py new file mode 100644 index 000000000..874cdef85 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -0,0 +1,42 @@ +import torch +import torch.nn.functional as F +from xformers import ops as xformers_ops +from xformers.ops import fmha +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class XformersVitAttBackend(BaseVitAttBackend): + @torch.no_grad() + def _vit_att_fwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + assert q.shape == k.shape == v.shape == o.shape + + T, H, D = q.shape + + cu_cpu = cu_seqlens if cu_seqlens.device.type == "cpu" else cu_seqlens.to("cpu") + seqlens = (cu_cpu[1:] - cu_cpu[:-1]).to(torch.int64).tolist() + seqlens = [int(L) for L in seqlens if int(L) > 0] + + if len(seqlens) == 0: + return o + if max_seqlen: + assert max(seqlens) <= max_seqlen + + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device) + + q_ = q.unsqueeze(0) # [1, T, H, D] + k_ = k.unsqueeze(0) # [1, T, H, D] + v_ = v.unsqueeze(0) # [1, T, H, D] + + out = xformers_ops.memory_efficient_attention(q_, k_, v_, attn_bias=attn_bias, p=0.0) + 
o.copy_(out.squeeze(0)) # [T, H, D] + return o diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 199f0dccc..b6fffce4e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -337,11 +337,11 @@ def make_argument_parser() -> argparse.ArgumentParser: "--vit_att_backend", type=str, nargs="+", - choices=["auto", "triton", "fa3", "sdpa"], + choices=["auto", "triton", "fa3", "sdpa", "xformers"], default=["auto"], help="""vit attention kernel used in vlm. auto: automatically select best backend based on GPU and available packages - (priority: fa3 > sdpa > triton)""", + (priority: fa3 > xformers > sdpa > triton)""", ) parser.add_argument( "--llm_kv_type", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 0608d3972..d914a1736 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -121,7 +121,9 @@ class StartArgs: llm_decode_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) - vit_att_backend: List[str] = field(default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa"]}) + vit_att_backend: List[str] = field( + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + ) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) From b7df5ce6aa593a91e3ce52cb64d081d098a3f1e2 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 27 Jan 2026 07:56:05 +0000 Subject: [PATCH 15/27] add-xformers --- lightllm/common/basemodel/attention_vit/sdpa/fp.py | 7 +++---- lightllm/common/basemodel/attention_vit/xformers/fp.py | 5 +---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py index 9c7d5e311..aea8b19bb 100644 --- a/lightllm/common/basemodel/attention_vit/sdpa/fp.py +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -16,13 +16,12 @@ def _vit_att_fwd( assert q.ndim == k.ndim == v.ndim == o.ndim == 3 assert cu_seqlens is not None and cu_seqlens.ndim == 1 - cu = cu_seqlens.to(device=q.device) - B = cu.numel() - 1 + B = cu_seqlens.numel() - 1 with torch.no_grad(): for b in range(B): - s = int(cu[b].item()) - e = int(cu[b + 1].item()) + s = int(cu_seqlens[b].item()) + e = int(cu_seqlens[b + 1].item()) L = e - s if L <= 0: continue diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py index 874cdef85..be1d33948 100644 --- a/lightllm/common/basemodel/attention_vit/xformers/fp.py +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -20,10 +20,7 @@ def _vit_att_fwd( assert cu_seqlens is not None and cu_seqlens.ndim == 1 assert q.shape == k.shape == v.shape == o.shape - T, H, D = q.shape - - cu_cpu = cu_seqlens if cu_seqlens.device.type == "cpu" else cu_seqlens.to("cpu") - seqlens = (cu_cpu[1:] - cu_cpu[:-1]).to(torch.int64).tolist() + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64).tolist() seqlens = [int(L) for L in seqlens if int(L) > 0] if len(seqlens) == 0: From 91805c1ecfb75d595a1112592bbaae0ea0c65417 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 27 Jan 2026 10:47:25 +0000 Subject: [PATCH 16/27] fix0127 --- 
lightllm/common/basemodel/attention_vit/create_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index fee5a7bf1..3aa23f16d 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -34,7 +34,7 @@ def get_vit_att_backend_class( def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> type: """Auto-select the best available backend with validation for VIT. - Priority: FA3 > Sdpa > Triton + Priority: FA3 > Xformers > Sdpa > Triton Each backend is validated in a subprocess with ground truth checks. """ backend_map = vit_att_backend From 5400970800b19756c65bd783bb4d23ec1ee3133d Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 27 Jan 2026 11:00:04 +0000 Subject: [PATCH 17/27] fix0127 --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 8d9a011be..a3b9473f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,6 +87,7 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 +xformers==0.0.32.post1 xxhash==3.6.0 torchvision==0.23.0 interegular==0.3.3 From d82dbe42251a64c14c05071d64359b678c5802f4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 28 Jan 2026 02:38:21 +0000 Subject: [PATCH 18/27] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E8=AF=B4=E6=98=8E=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/CN/source/tutorial/api_server_args.rst | 11 +++++++++++ docs/EN/source/tutorial/api_server_args.rst | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst index b7f6312a6..003fec32e 100644 --- a/docs/CN/source/tutorial/api_server_args.rst +++ b/docs/CN/source/tutorial/api_server_args.rst @@ -284,6 +284,17 @@ PD 分离模式参数 为 ViT 构建分布式环境的 NCCL 端口列表,例如 29500 29501 29502,默认为 [29500] +.. option:: --vit_att_backend + + 设置 ViT 使用的注意力后端。可选值为: + + * ``auto``: 自动选择最佳后端(默认值),优先级为 fa3 > xformers > sdpa > triton + * ``fa3``: 使用 Flash-Attention 3 后端 + * ``xformers``: 使用 xformers 后端 + * ``sdpa``: 使用 sdpa 后端 + * ``triton``: 使用 Triton 后端 + + 性能优化参数 ------------ diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst index 18fe54c55..d6f3bd80b 100644 --- a/docs/EN/source/tutorial/api_server_args.rst +++ b/docs/EN/source/tutorial/api_server_args.rst @@ -282,6 +282,17 @@ Multimodal Parameters List of NCCL ports for ViT, e.g., 29500 29501 29502, default is [29500] +.. option:: --vit_att_backend + + Set the attention backend for ViT. 
Available options: + + * ``auto``: Automatically select the best backend (default), with priority fa3 > xformers > sdpa > triton + * ``fa3``: Use Flash-Attention 3 backend + * ``xformers``: Use xformers backend + * ``sdpa``: Use sdpa backend + * ``triton``: Use Triton backend + + Performance Optimization Parameters ----------------------------------- From fb34c76bb88dea688580f55a3b6528da6e03358a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 05:54:53 +0000 Subject: [PATCH 19/27] simplify port alloc --- lightllm/server/api_cli.py | 12 +++--- lightllm/server/api_start.py | 43 ++++++++------------ lightllm/server/core/objs/start_args_type.py | 12 +++--- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b6fffce4e..a513b6dc3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -39,7 +39,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--pd_decode_rpyc_port", type=int, - default=42000, + default=None, help="p d mode, decode node used for kv move manager rpyc server port", ) parser.add_argument( @@ -119,7 +119,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--batch_max_tokens", type=int, - default=None, + default=16384, help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", ) parser.add_argument( @@ -210,7 +210,7 @@ def make_argument_parser() -> argparse.ArgumentParser: When deploying in multi-node manner, the value should be set to the IP of the master node""", ) parser.add_argument( - "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" + "--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch" ) parser.add_argument( "--use_config_server_to_init_nccl", @@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") - parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") @@ -400,7 +400,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") parser.add_argument( - "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" + "--visual_infer_batch_size", type=int, default=None, help="number of images to process in each inference batch" ) parser.add_argument( "--visual_send_batch_size", @@ -420,7 +420,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=[29500], + default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3ae3789f4..576e558a4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -149,18 
+149,12 @@ def normal_or_p_d_start(args): else: args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] - # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: - raise ValueError( - f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - f"but got ({len(args.visual_nccl_ports)})." - ) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] - if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") + if args.visual_infer_batch_size is None: + args.visual_infer_batch_size = args.visual_dp + # 检查visual_infer_batch_size是否合理 if args.visual_infer_batch_size // args.visual_dp < 1 or args.visual_infer_batch_size % args.visual_dp != 0: raise ValueError( @@ -171,15 +165,11 @@ def normal_or_p_d_start(args): if args.disable_chunked_prefill: args.chunked_prefill_size = args.max_req_total_len # 普通模式下 - if args.batch_max_tokens is None: - args.batch_max_tokens = args.max_req_total_len - else: - assert args.batch_max_tokens >= args.max_req_total_len, "batch_max_tokens must >= max_req_total_len" + args.batch_max_tokens = args.max_req_total_len else: # chunked 模式下 - if args.batch_max_tokens is None: - args.batch_max_tokens = min(args.max_req_total_len, 2 * args.chunked_prefill_size + 256) - + args.batch_max_tokens = args.batch_max_tokens // args.dp + args.chunked_prefill_size = args.batch_max_tokens // 2 assert ( args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, " @@ -203,10 +193,7 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] - if args.run_mode == "decode": - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port] - + already_uesd_ports = [args.port] # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 ports_locker = PortLocker(already_uesd_ports) @@ -214,10 +201,11 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( + nccl_port, router_port, detokenization_port, http_server_port, @@ -226,16 +214,20 @@ def normal_or_p_d_start(args): cache_port, metric_port, multi_level_kv_cache_port, - ) = can_use_ports[0:8] - can_use_ports = can_use_ports[8:] + pd_decode_rpyc_port, + ) = can_use_ports[0:10] + can_use_ports = can_use_ports[10:] visual_model_tp_ports = [] + visual_nccl_ports = [] for _ in range(args.visual_dp): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - can_use_ports = can_use_ports[args.visual_tp :] visual_model_tp_ports.append(tp_ports_for_dp) + visual_nccl_ports.append(can_use_ports[args.visual_tp]) + can_use_ports = can_use_ports[args.visual_tp + 1 :] # 将申请好的端口放入args参数中 + args.nccl_port = nccl_port args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -244,7 +236,8 @@ def normal_or_p_d_start(args): args.cache_port = cache_port args.metric_port = metric_port args.multi_level_kv_cache_port = multi_level_kv_cache_port - + args.pd_decode_rpyc_port = pd_decode_rpyc_port + args.visual_nccl_ports = visual_nccl_ports 
# 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d914a1736..897c48e93 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -20,7 +20,7 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) - pd_decode_rpyc_port: int = field(default=42000) + pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) @@ -61,7 +61,7 @@ class StartArgs: node_rank: int = field(default=0) max_req_total_len: int = field(default=2048 + 1024) nccl_host: str = field(default="127.0.0.1") - nccl_port: int = field(default=28765) + nccl_port: int = field(default=None) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) @@ -71,7 +71,7 @@ class StartArgs: router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - chunked_prefill_size: int = field(default=8192) + chunked_prefill_size: int = field(default=None) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) @@ -97,12 +97,12 @@ class StartArgs: job_name: str = field(default="lightllm") grouping_key: List[str] = field(default_factory=list) push_interval: int = field(default=10) - visual_infer_batch_size: int = field(default=1) + visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) visual_tp: int = field(default=1) visual_dp: int = field(default=1) - visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) + visual_nccl_ports: List[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) @@ -110,7 +110,7 @@ class StartArgs: graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) - graph_max_len_in_batch: int = field(default=8192) + graph_max_len_in_batch: int = field(default=0) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) vit_quant_type: Optional[str] = field(default=None) From ef5f4a3b2c5eaa7e51e73366d25db2d6fd93d84d Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 06:01:01 +0000 Subject: [PATCH 20/27] safe import of xformers --- .../common/basemodel/attention_vit/xformers/fp.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py index be1d33948..04772b20c 100644 --- a/lightllm/common/basemodel/attention_vit/xformers/fp.py +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -1,7 +1,16 @@ import torch import torch.nn.functional as F -from xformers import ops as xformers_ops -from xformers.ops import fmha +from lightllm.utils.log_utils import init_logger + +logger = 
init_logger(__name__) +try: + from xformers import ops as xformers_ops + from xformers.ops import fmha +except ImportError: + xformers_ops = None + fmha = None + logger.warning("xformers or flash-attn is not installed, please ensure xformers and flash-attn is installed.") + from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend From f790bc40d1ec161325b17937661346c047575db9 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 28 Jan 2026 06:06:51 +0000 Subject: [PATCH 21/27] fix0128 --- .../basemodel/attention_vit/create_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 3aa23f16d..1b543a1b6 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -104,19 +104,20 @@ def _validate_fa3(): def _validate_xformers(): """Validate Xformers Attn""" - from xformers import ops as xformers_ops + try: + import xformers.ops + from xformers.ops import fmha - if xformers_ops is None: + return True + except ImportError: return False - return True - def _validate_sdpa(): """Validate SDPA Attn""" - from torch.nn.functional import scaled_dot_product_attention + try: + from torch.nn.functional import scaled_dot_product_attention - if scaled_dot_product_attention is None: + return True + except ImportError: return False - - return True From 1cbbe1bb66abe6ec5c92a9c90fe9baa552dd334b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 28 Jan 2026 07:43:23 +0000 Subject: [PATCH 22/27] fix0128 --- .../basemodel/attention_vit/create_utils.py | 78 +---------------- .../common/basemodel/attention_vit/sdpa/fp.py | 6 +- lightllm/utils/backend_validator.py | 84 ++++++++++++++++++- lightllm/utils/dist_utils.py | 9 ++ 4 files changed, 97 insertions(+), 80 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 1b543a1b6..63b3a182e 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -1,7 +1,7 @@ import torch from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth +from lightllm.utils.backend_validator import validate_vit from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend @@ -40,84 +40,10 @@ def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "trito backend_map = vit_att_backend for backend_name in priority_list: - if validate(backend_name): + if validate_vit(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated) for VIT") return backend_map[backend_name] # Fallback to triton without validation (should not happen) logger.warning("No backend validation succeeded, falling back to triton") return backend_map["triton"] - - -def validate(backend_name: str) -> bool: - if backend_name == "fa3": - validate_ok = _validate_fa3() - elif backend_name == "xformers": - validate_ok = _validate_xformers() - elif backend_name == "sdpa": - validate_ok = _validate_sdpa() - elif backend_name == "triton": - validate_ok = _validate_triton() - else: - 
raise ValueError("not suuported vit attn backend") - return validate_ok - - -def _validate_fa3(): - """Validate FA3 with ground truth.""" - from lightllm.utils.sgl_utils import flash_attn_varlen_func - - if flash_attn_varlen_func is None: - return False - - batch, heads, seq, dim = 1, 4, 8, 64 - q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") - k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") - v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") - - expected = _compute_ground_truth(q, k, v) - - q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim) - k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim) - v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim) - cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda") - - out = flash_attn_varlen_func( - q=q_flat, - k=k_flat, - v=v_flat, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=seq, - max_seqlen_k=seq, - softmax_scale=1.0 / (dim ** 0.5), - causal=True, - ) - out = out.reshape(batch, seq, heads, dim).transpose(1, 2) - torch.cuda.synchronize() - - if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): - return False - - return True - - -def _validate_xformers(): - """Validate Xformers Attn""" - try: - import xformers.ops - from xformers.ops import fmha - - return True - except ImportError: - return False - - -def _validate_sdpa(): - """Validate SDPA Attn""" - try: - from torch.nn.functional import scaled_dot_product_attention - - return True - except ImportError: - return False diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py index aea8b19bb..874452f0f 100644 --- a/lightllm/common/basemodel/attention_vit/sdpa/fp.py +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -15,13 +15,13 @@ def _vit_att_fwd( ) -> torch.Tensor: assert q.ndim == k.ndim == v.ndim == o.ndim == 3 assert cu_seqlens is not None and cu_seqlens.ndim == 1 - + cu_seqlens = cu_seqlens.detach().to("cpu") B = cu_seqlens.numel() - 1 with torch.no_grad(): for b in range(B): - s = int(cu_seqlens[b].item()) - e = int(cu_seqlens[b + 1].item()) + s = int(cu_seqlens[b]) + e = int(cu_seqlens[b + 1]) L = e - s if L <= 0: continue diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 1f02c3952..142085ecf 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -4,7 +4,7 @@ import os import torch from lightllm.utils.log_utils import init_logger -from lightllm.utils.dist_utils import get_global_rank +from lightllm.utils.dist_utils import get_global_rank, get_global_vit_rank from functools import lru_cache logger = init_logger(__name__) @@ -121,6 +121,72 @@ def _softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK: tl.constexpr): return True, None +def _validate_xformers(): + """Validate Xformers Attn with ground truth.""" + try: + import torch + + if not torch.cuda.is_available(): + return False, "CUDA not available" + + import xformers.ops as xformers_ops + from xformers.ops import fmha + except Exception as e: + return False, f"xformers import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 4, 8, 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + k = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + v = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + + expected = 
_compute_ground_truth(q, k, v, is_causal=False) + + q_bmhd = q.transpose(1, 2).contiguous() # (B, seq, heads, dim) + k_bmhd = k.transpose(1, 2).contiguous() + v_bmhd = v.transpose(1, 2).contiguous() + + try: + out = xformers_ops.memory_efficient_attention(q_bmhd, k_bmhd, v_bmhd, p=0.0) + except Exception as e: + return False, f"xformers attention run failed: {type(e).__name__}: {e}" + + out = out.transpose(1, 2).contiguous() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + + return True, None + + +def _validate_sdpa(): + """Validate SDPA Attn with ground truth.""" + try: + import torch + from torch.nn.functional import scaled_dot_product_attention + except Exception as e: + return False, f"SDPA import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 4, 8, 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + k = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + v = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + + expected = _compute_ground_truth(q, k, v, is_causal=False) + + out = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + + return True, None + + def _run_in_subprocess(backend_name, pipe): """Run validation in subprocess with suppressed output.""" import sys @@ -133,6 +199,10 @@ def _run_in_subprocess(backend_name, pipe): try: if backend_name == "fa3": success, err = _validate_fa3() + elif backend_name == "xformers": + success, err = _validate_xformers() + elif backend_name == "sdpa": + success, err = _validate_sdpa() elif backend_name == "flashinfer": success, err = _validate_flashinfer() elif backend_name == "triton": @@ -158,6 +228,18 @@ def validate(backend_name: str) -> bool: return validate_ok +@lru_cache(maxsize=None) +def validate_vit(backend_name: str) -> bool: + if get_global_vit_rank() == 0: + validate_ok = _validate(backend_name) + torch.distributed.broadcast_object_list([validate_ok], src=0) + else: + validate_ok = [None] + torch.distributed.broadcast_object_list(validate_ok, src=0) + validate_ok = validate_ok[0] + return validate_ok + + def _validate(backend_name: str) -> bool: """Validate backend in subprocess with ground truth check.""" try: diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4..77d11abe4 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -63,6 +63,7 @@ def init_vision_distributed_env(kvargs): set_current_rank_in_dp(tp_rank_id) visual_gpu_ids = kvargs["visual_gpu_ids"] device_id = visual_gpu_ids[kvargs["vit_rank_id"]] + set_global_vit_rank(device_id) set_current_device_id(device_id) torch.cuda.set_device(device_id) dist.init_process_group( @@ -112,6 +113,14 @@ def init_distributed_env(kvargs): del _a +def set_global_vit_rank(global_vit_rank: int): + set_environ("LIGHTLLM_GLOBAL_VIT_RANK", global_vit_rank) + + +def get_global_vit_rank(): + return int(get_environ("LIGHTLLM_GLOBAL_VIT_RANK")) + + def set_global_rank(global_rank: int): set_environ("LIGHTLLM_GLOBAL_RANK", global_rank) From 1835cb7a17351ca082f3840c9cfae2ada8e94882 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 08:56:49 +0000 Subject: [PATCH 23/27] fix chunked_prefill_size, batch_max_tokens --- 
lightllm/server/api_cli.py | 2 +- lightllm/server/api_start.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index a513b6dc3..09fd52138 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -119,7 +119,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--batch_max_tokens", type=int, - default=16384, + default=None, help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 576e558a4..8e5b23d09 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -165,11 +165,17 @@ def normal_or_p_d_start(args): if args.disable_chunked_prefill: args.chunked_prefill_size = args.max_req_total_len # 普通模式下 - args.batch_max_tokens = args.max_req_total_len + if args.batch_max_tokens is None: + args.batch_max_tokens = args.max_req_total_len + else: + assert args.batch_max_tokens >= args.max_req_total_len, ( + f"batch_max_tokens must >= max_req_total_len, but got {args.batch_max_tokens}, {args.max_req_total_len}") else: # chunked 模式下 - args.batch_max_tokens = args.batch_max_tokens // args.dp - args.chunked_prefill_size = args.batch_max_tokens // 2 + if args.batch_max_tokens is None: + args.batch_max_tokens = 16384 // args.dp + if args.chunked_prefill_size is None: + args.chunked_prefill_size = args.batch_max_tokens // 2 assert ( args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, " @@ -223,8 +229,9 @@ def normal_or_p_d_start(args): for _ in range(args.visual_dp): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] visual_model_tp_ports.append(tp_ports_for_dp) - visual_nccl_ports.append(can_use_ports[args.visual_tp]) - can_use_ports = can_use_ports[args.visual_tp + 1 :] + can_use_ports = can_use_ports[args.visual_tp :] + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] # 将申请好的端口放入args参数中 args.nccl_port = nccl_port From 7fc26322ed6ae550591b18a0b7342065af13c672 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 28 Jan 2026 10:20:15 +0000 Subject: [PATCH 24/27] fix0128 --- .../basemodel/attention_vit/create_utils.py | 28 ++++++++++--------- .../common/basemodel/attention_vit/fa3/fp.py | 2 +- .../common/basemodel/attention_vit/sdpa/fp.py | 2 +- .../basemodel/attention_vit/triton/fp.py | 2 +- .../basemodel/attention_vit/xformers/fp.py | 2 +- lightllm/server/visualserver/__init__.py | 4 +-- lightllm/server/visualserver/manager.py | 4 ++- .../visualserver/model_infer/model_rpc.py | 6 ++-- lightllm/utils/backend_validator.py | 14 +--------- lightllm/utils/dist_utils.py | 9 ------ 10 files changed, 28 insertions(+), 45 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py index 1b543a1b6..63b3a182e 100644 --- a/lightllm/common/basemodel/attention_vit/create_utils.py +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -1,7 +1,7 @@ import torch from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth +from lightllm.utils.backend_validator import _validate from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend from
lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend @@ -19,31 +19,33 @@ } -def get_vit_att_backend_class( - index=0, priority_list: list = ["fa3", "xformers", "sdpa", "triton"] -) -> BaseVitAttBackend: +def get_vit_att_backend_class(backend_name: str) -> BaseVitAttBackend: + vit_att_backend_class = vit_att_backend[backend_name] + return vit_att_backend_class + + +def init_vit_att_backend(index=0, priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> str: args = get_env_start_args() - backend_str = args.vit_att_backend[index] - if backend_str != "auto": - logger.info(f"Selected {backend_str} backend for VIT") - return vit_att_backend[backend_str] + backend_name = args.vit_att_backend[index] + if backend_name != "auto": + logger.info(f"Selected {backend_name} backend for VIT") + return backend_name else: return _select_vit_backend(priority_list=priority_list) -def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> type: +def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> str: """Auto-select the best available backend with validation for VIT. Priority: FA3 > Xformers > Sdpa > Triton Each backend is validated in a subprocess with ground truth checks. """ - backend_map = vit_att_backend for backend_name in priority_list: - if validate_vit(backend_name): + if _validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated) for VIT") - return backend_map[backend_name] + return backend_name # Fallback to triton without validation (should not happen) logger.warning("No backend validation succeeded, falling back to triton") - return backend_map["triton"] + return "triton" diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index e77a0cec7..406ff7408 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -4,8 +4,8 @@ class Fa3VitAttBackend(BaseVitAttBackend): + @staticmethod def _vit_att_fwd( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py index 874452f0f..6c6da2c2b 100644 --- a/lightllm/common/basemodel/attention_vit/sdpa/fp.py +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -4,8 +4,8 @@ class SdpaVitAttBackend(BaseVitAttBackend): + @staticmethod def _vit_att_fwd( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index c38a46633..88867102a 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -4,8 +4,8 @@ class TritonVitAttBackend(BaseVitAttBackend): + @staticmethod def _vit_att_fwd( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py index 04772b20c..be9411d5c 100644 --- a/lightllm/common/basemodel/attention_vit/xformers/fp.py +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -16,8 +16,8 @@ class XformersVitAttBackend(BaseVitAttBackend): @torch.no_grad() + @staticmethod def _vit_att_fwd( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index 
026458447..6bc0923da 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -4,9 +4,9 @@ VIT_ATTN_BACKEND: BaseVitAttBackend = None -def init_vit_att_backend(): +def set_vit_att_backend(backend_name: str): global VIT_ATTN_BACKEND - VIT_ATTN_BACKEND = get_vit_att_backend_class(index=0)() + VIT_ATTN_BACKEND = get_vit_att_backend_class(backend_name) return diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a389272e5..202c2fc45 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -14,6 +14,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer.model_rpc import start_model_process, VisualModelRpcClient +from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -63,7 +64,7 @@ def __init__( async def wait_to_model_ready(self): self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] - + self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] for tp_rank_id in range(self.vit_tp): @@ -91,6 +92,7 @@ async def wait_to_model_ready(self): "quant_type": self.args.vit_quant_type, "quant_cfg": self.args.vit_quant_cfg, "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), + "vit_attn_backend": self.vit_attn_backend, } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a81e89efe..8f4a1ee45 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,8 +24,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.server.visualserver import init_vit_att_backend -from lightllm.utils.dist_utils import set_global_rank +from lightllm.server.visualserver import set_vit_att_backend class VisualModelRpcServer(rpyc.Service): @@ -44,8 +43,9 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] + self.vit_attn_backend = kvargs["vit_attn_backend"] + set_vit_att_backend(self.vit_attn_backend) init_vision_distributed_env(kvargs) - init_vit_att_backend() model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 142085ecf..0e2f9c962 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -4,7 +4,7 @@ import os import torch from lightllm.utils.log_utils import init_logger -from lightllm.utils.dist_utils import get_global_rank, get_global_vit_rank +from lightllm.utils.dist_utils import get_global_rank from functools import lru_cache logger = init_logger(__name__) @@ -228,18 +228,6 @@ def 
validate(backend_name: str) -> bool: return validate_ok -@lru_cache(maxsize=None) -def validate_vit(backend_name: str) -> bool: - if get_global_vit_rank() == 0: - validate_ok = _validate(backend_name) - torch.distributed.broadcast_object_list([validate_ok], src=0) - else: - validate_ok = [None] - torch.distributed.broadcast_object_list(validate_ok, src=0) - validate_ok = validate_ok[0] - return validate_ok - - def _validate(backend_name: str) -> bool: """Validate backend in subprocess with ground truth check.""" try: diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 77d11abe4..65ac401d4 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -63,7 +63,6 @@ def init_vision_distributed_env(kvargs): set_current_rank_in_dp(tp_rank_id) visual_gpu_ids = kvargs["visual_gpu_ids"] device_id = visual_gpu_ids[kvargs["vit_rank_id"]] - set_global_vit_rank(device_id) set_current_device_id(device_id) torch.cuda.set_device(device_id) dist.init_process_group( @@ -113,14 +112,6 @@ def init_distributed_env(kvargs): del _a -def set_global_vit_rank(global_vit_rank: int): - set_environ("LIGHTLLM_GLOBAL_VIT_RANK", global_vit_rank) - - -def get_global_vit_rank(): - return int(get_environ("LIGHTLLM_GLOBAL_VIT_RANK")) - - def set_global_rank(global_rank: int): set_environ("LIGHTLLM_GLOBAL_RANK", global_rank) From 6a921ebaa200c59d71df8070779b300cb3944973 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 11:36:59 +0000 Subject: [PATCH 25/27] remove xformers logger --- lightllm/common/basemodel/attention_vit/xformers/fp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py index be9411d5c..361b5db05 100644 --- a/lightllm/common/basemodel/attention_vit/xformers/fp.py +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -1,15 +1,12 @@ import torch import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -logger = init_logger(__name__) try: from xformers import ops as xformers_ops from xformers.ops import fmha except ImportError: xformers_ops = None fmha = None - logger.warning("xformers or flash-attn is not installed, please ensure xformers and flash-attn is installed.") from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend From df99ab8943f3f104767da0658ffa89a00decf036 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:41:40 +0800 Subject: [PATCH 26/27] fix nccl port. 
--- lightllm/server/api_start.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8e5b23d09..8f2088897 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -200,6 +200,9 @@ def normal_or_p_d_start(args): assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] already_uesd_ports = [args.port] + if args.nccl_port is not None: + already_uesd_ports.append(args.nccl_port) + # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 ports_locker = PortLocker(already_uesd_ports) @@ -234,7 +237,8 @@ def normal_or_p_d_start(args): can_use_ports = can_use_ports[1:] # 将申请好的端口放入args参数中 - args.nccl_port = nccl_port + if args.nccl_port is None: + args.nccl_port = nccl_port args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port From 2171142b57faa44f95dbbf1497d6cac918fb882b Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 28 Jan 2026 12:50:38 +0000 Subject: [PATCH 27/27] fix pd_decode_rpyc_port --- lightllm/server/api_start.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8f2088897..bd8e4db8b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -202,6 +202,8 @@ def normal_or_p_d_start(args): already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) + if args.pd_decode_rpyc_port is not None: + already_uesd_ports.append(args.pd_decode_rpyc_port) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -239,6 +241,8 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 if args.nccl_port is None: args.nccl_port = nccl_port + if args.pd_decode_rpyc_port is None: + args.pd_decode_rpyc_port = pd_decode_rpyc_port args.router_port = router_port args.detokenization_port = detokenization_port args.http_server_port = http_server_port @@ -247,7 +251,6 @@ def normal_or_p_d_start(args): args.cache_port = cache_port args.metric_port = metric_port args.multi_level_kv_cache_port = multi_level_kv_cache_port - args.pd_decode_rpyc_port = pd_decode_rpyc_port args.visual_nccl_ports = visual_nccl_ports # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
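The ViT attention backends introduced in this series all take packed patch tokens of shape (total_tokens, heads, head_dim) together with a cu_seqlens boundary tensor: the xformers backend keeps images independent with fmha.BlockDiagonalMask.from_seqlens, while the SDPA backend loops over cu_seqlens and calls scaled_dot_product_attention per image. A minimal standalone sketch of that equivalence follows (illustrative only: it assumes xformers and a CUDA device are available, and the token counts and variable names are made up rather than part of the backends' API):

# Two "images" of 3 and 5 patch tokens packed into one varlen batch; the
# xformers block-diagonal path and a per-image SDPA loop should agree closely.
import torch
import torch.nn.functional as F
from xformers.ops import fmha, memory_efficient_attention

heads, dim = 4, 64
seqlens = [3, 5]  # assumed per-image token counts
q = torch.randn(sum(seqlens), heads, dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Packed path: batch dim of 1, block-diagonal bias keeps images independent.
bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
out = memory_efficient_attention(q[None], k[None], v[None], attn_bias=bias, p=0.0)[0]

# Reference path: run scaled_dot_product_attention image by image.
ref = torch.empty_like(q)
start = 0
for length in seqlens:
    s = slice(start, start + length)
    # SDPA expects (batch, heads, seq, dim), so move heads in front of seq.
    o = F.scaled_dot_product_attention(
        q[s].transpose(0, 1)[None], k[s].transpose(0, 1)[None], v[s].transpose(0, 1)[None]
    )
    ref[s] = o[0].transpose(0, 1)
    start += length

print(torch.allclose(out, ref, rtol=1e-2, atol=1e-2))  # expected: True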
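PATCH 23 also changes how the chunked-prefill defaults are derived when the flags are left unset: batch_max_tokens falls back to 16384 // dp and chunked_prefill_size to half of batch_max_tokens, while an explicitly set batch_max_tokens is kept as given. A small sketch of that defaulting arithmetic (the helper function is illustrative; in the patch the logic lives inline in normal_or_p_d_start):

def chunked_prefill_defaults(dp, batch_max_tokens=None, chunked_prefill_size=None):
    # Mirrors the chunked-prefill branch: fill in unset values, then sanity-check.
    if batch_max_tokens is None:
        batch_max_tokens = 16384 // dp
    if chunked_prefill_size is None:
        chunked_prefill_size = batch_max_tokens // 2
    assert batch_max_tokens >= chunked_prefill_size
    return batch_max_tokens, chunked_prefill_size

print(chunked_prefill_defaults(dp=2))                         # (8192, 4096)
print(chunked_prefill_defaults(dp=1, batch_max_tokens=4096))  # (4096, 2048)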