diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst
index b7f6312a6..003fec32e 100644
--- a/docs/CN/source/tutorial/api_server_args.rst
+++ b/docs/CN/source/tutorial/api_server_args.rst
@@ -284,6 +284,17 @@ PD 分离模式参数
 
     为 ViT 构建分布式环境的 NCCL 端口列表，例如 29500 29501 29502，默认为 [29500]
 
+.. option:: --vit_att_backend
+
+    设置 ViT 使用的注意力后端。可选值为:
+
+    * ``auto``: 自动选择最佳后端(默认值)，优先级为 fa3 > xformers > sdpa > triton
+    * ``fa3``: 使用 Flash-Attention 3 后端
+    * ``xformers``: 使用 xformers 后端
+    * ``sdpa``: 使用 sdpa 后端
+    * ``triton``: 使用 Triton 后端
+
+
 性能优化参数
 ------------
 
diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst
index 18fe54c55..d6f3bd80b 100644
--- a/docs/EN/source/tutorial/api_server_args.rst
+++ b/docs/EN/source/tutorial/api_server_args.rst
@@ -282,6 +282,17 @@ Multimodal Parameters
 
     List of NCCL ports for ViT, e.g., 29500 29501 29502, default is [29500]
 
+.. option:: --vit_att_backend
+
+    Set the attention backend for ViT. Available options:
+
+    * ``auto``: Automatically select the best backend (default), with priority fa3 > xformers > sdpa > triton
+    * ``fa3``: Use Flash-Attention 3 backend
+    * ``xformers``: Use xformers backend
+    * ``sdpa``: Use sdpa backend
+    * ``triton``: Use Triton backend
+
+
 Performance Optimization Parameters
 -----------------------------------
 
diff --git a/lightllm/common/basemodel/attention_vit/__init__.py b/lightllm/common/basemodel/attention_vit/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py
new file mode 100644
index 000000000..49bf6ad74
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/base_att.py
@@ -0,0 +1,38 @@
+import torch
+from abc import ABC, abstractmethod
+
+
+class BaseVitAttBackend(ABC):
+    """
+    用于创建支持各种不同的AttBackend, 如 fa3, sdpa, triton 等实现,
+    这个是单例模式, 每种backend只有一个实例
+    """
+
+    _instances = {}
+
+    def __new__(cls, *args, **kwargs):
+        """
+        重写__new__方法实现单例模式
+        """
+        # 检查是否已经有该类的实例
+        if cls not in cls._instances:
+            # 创建新实例并存储
+            instance = super().__new__(cls)
+            cls._instances[cls] = instance
+        # 返回已有的实例
+        return cls._instances[cls]
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def _vit_att_fwd(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        o: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ) -> torch.Tensor:
+        raise NotImplementedError("not impl")
diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py
new file mode 100644
index 000000000..67f830ba0
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/create_utils.py
@@ -0,0 +1,51 @@
+import torch
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.backend_validator import _validate
+from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend
+from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend
+from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend
+from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend
+
+logger = init_logger(__name__)
+
+
+vit_att_backend = {
+    "triton": TritonVitAttBackend,
+    "sdpa": SdpaVitAttBackend,
+    "fa3": Fa3VitAttBackend,
+    "xformers": XformersVitAttBackend,
+}
+
+
+def 
get_vit_att_backend_class(backend_name: str) -> BaseVitAttBackend: + vit_att_backend_class = vit_att_backend[backend_name] + return vit_att_backend_class + + +def init_vit_att_backend(index=0, priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> str: + args = get_env_start_args() + backend_name = args.vit_att_backend[index] + if backend_name != "auto": + logger.info(f"Selected {backend_name} backend for VIT") + return backend_name + else: + return _select_vit_backend(priority_list=priority_list) + + +def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> str: + """Auto-select the best available backend with validation for VIT. + + Priority: FA3 > Xformers > Sdpa > Triton + Each backend is validated in a subprocess with ground truth checks. + """ + + for backend_name in priority_list: + if _validate(backend_name): + logger.info(f"Auto-selected {backend_name} backend (validated) for VIT") + return backend_name + + # Fallback to triton without validation (should not happen) + logger.warning("No backend validation succeeded, falling back to triton") + return "triton" diff --git a/lightllm/common/basemodel/attention_vit/fa3/__init__.py b/lightllm/common/basemodel/attention_vit/fa3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py new file mode 100644 index 000000000..406ff7408 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -0,0 +1,57 @@ +import dataclasses +import torch +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class Fa3VitAttBackend(BaseVitAttBackend): + @staticmethod + def _vit_att_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + + head_dim = q.shape[-1] + softmax_scale = head_dim ** -0.5 + window_size = (-1, -1) + torch.ops.sgl_kernel.fwd.default( + q, + k, + v, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention_vit/sdpa/__init__.py b/lightllm/common/basemodel/attention_vit/sdpa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py new file mode 100644 index 000000000..6c6da2c2b --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -0,0 +1,47 @@ +import torch +import torch.nn.functional as F +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class SdpaVitAttBackend(BaseVitAttBackend): + @staticmethod + def _vit_att_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + cu_seqlens = cu_seqlens.detach().to("cpu") + B = cu_seqlens.numel() - 1 + + with torch.no_grad(): + for b in range(B): + 
s = int(cu_seqlens[b])
+                e = int(cu_seqlens[b + 1])
+                L = e - s
+                if L <= 0:
+                    continue
+                if max_seqlen:
+                    assert L <= max_seqlen
+
+                # [L, H, D] -> [1, H, L, D]
+                q_ = q[s:e].permute(1, 0, 2).unsqueeze(0)
+                k_ = k[s:e].permute(1, 0, 2).unsqueeze(0)
+                v_ = v[s:e].permute(1, 0, 2).unsqueeze(0)
+
+                out = F.scaled_dot_product_attention(
+                    q_,
+                    k_,
+                    v_,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                )
+                # [1, H, L, D] -> [L, H, D]
+                o[s:e].copy_(out.squeeze(0).permute(1, 0, 2))
+
+        return o
diff --git a/lightllm/common/basemodel/attention_vit/triton/__init__.py b/lightllm/common/basemodel/attention_vit/triton/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py
new file mode 100644
index 000000000..88867102a
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/triton/fp.py
@@ -0,0 +1,24 @@
+import torch
+from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd
+
+
+class TritonVitAttBackend(BaseVitAttBackend):
+    @staticmethod
+    def _vit_att_fwd(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        o: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ):
+        _flash_attention_triton_fwd(
+            q,
+            k,
+            v,
+            o,
+            cu_seqlens,  # q k v cu_seqlens,
+            max_seqlen,
+        )
+        return o
diff --git a/lightllm/common/basemodel/attention_vit/xformers/__init__.py b/lightllm/common/basemodel/attention_vit/xformers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py
new file mode 100644
index 000000000..361b5db05
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+
+try:
+    from xformers import ops as xformers_ops
+    from xformers.ops import fmha
+except ImportError:
+    xformers_ops = None
+    fmha = None
+
+from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+
+
+class XformersVitAttBackend(BaseVitAttBackend):
+    @staticmethod
+    @torch.no_grad()
+    def _vit_att_fwd(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        o: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ) -> torch.Tensor:
+        assert q.ndim == k.ndim == v.ndim == o.ndim == 3
+        assert cu_seqlens is not None and cu_seqlens.ndim == 1
+        assert q.shape == k.shape == v.shape == o.shape
+
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64).tolist()
+        seqlens = [int(L) for L in seqlens if int(L) > 0]
+
+        if len(seqlens) == 0:
+            return o
+        if max_seqlen:
+            assert max(seqlens) <= max_seqlen
+
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device)
+
+        q_ = q.unsqueeze(0)  # [1, T, H, D]
+        k_ = k.unsqueeze(0)  # [1, T, H, D]
+        v_ = v.unsqueeze(0)  # [1, T, H, D]
+
+        out = xformers_ops.memory_efficient_attention(q_, k_, v_, attn_bias=attn_bias, p=0.0)
+        o.copy_(out.squeeze(0))  # [T, H, D]
+        return o
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
index 498f82e14..7156a5ce2 100644
--- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -13,7 +13,7 @@
 from lightllm.server.multimodal_params import ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from 
lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -74,7 +74,7 @@ def forward( k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 334ffc844..0e2af0cbb 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -32,8 +32,8 @@ from transformers.activations import ACT2FN from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -143,7 +143,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0d55d1b57..cada55e58 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -3,10 +3,10 @@ from typing import Tuple from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -105,7 +105,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor: q, k, v, out = map(reshape, (q, k, v, out)) cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len max_seqlen = seq_len - flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, out, cu_seqlens, max_seqlen) return out.reshape(batch_size, seq_len, -1) def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..09fd52138 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -39,7 +39,7 @@ def make_argument_parser() -> 
argparse.ArgumentParser: parser.add_argument( "--pd_decode_rpyc_port", type=int, - default=42000, + default=None, help="p d mode, decode node used for kv move manager rpyc server port", ) parser.add_argument( @@ -210,7 +210,7 @@ def make_argument_parser() -> argparse.ArgumentParser: When deploying in multi-node manner, the value should be set to the IP of the master node""", ) parser.add_argument( - "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" + "--nccl_port", type=int, default=None, help="the nccl_port to build a distributed environment for PyTorch" ) parser.add_argument( "--use_config_server_to_init_nccl", @@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") - parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") @@ -333,6 +333,16 @@ def make_argument_parser() -> argparse.ArgumentParser: auto: automatically select best backend based on GPU and available packages (priority: fa3 > flashinfer > triton)""", ) + parser.add_argument( + "--vit_att_backend", + type=str, + nargs="+", + choices=["auto", "triton", "fa3", "sdpa", "xformers"], + default=["auto"], + help="""vit attention kernel used in vlm. + auto: automatically select best backend based on GPU and available packages + (priority: fa3 > xformers > sdpa > triton)""", + ) parser.add_argument( "--llm_kv_type", type=str, @@ -390,7 +400,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") parser.add_argument( - "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" + "--visual_infer_batch_size", type=int, default=None, help="number of images to process in each inference batch" ) parser.add_argument( "--visual_send_batch_size", @@ -410,7 +420,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=[29500], + default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3ae3789f4..bd8e4db8b 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -149,18 +149,12 @@ def normal_or_p_d_start(args): else: args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] - # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: - raise ValueError( - f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - f"but got ({len(args.visual_nccl_ports)})." 
-        )
-    else:
-        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
-
     if args.visual_dp <= 0:
         raise ValueError("visual_dp must be a positive integer.")
 
+    if args.visual_infer_batch_size is None:
+        args.visual_infer_batch_size = args.visual_dp
+
     # 检查visual_infer_batch_size是否合理
     if args.visual_infer_batch_size // args.visual_dp < 1 or args.visual_infer_batch_size % args.visual_dp != 0:
         raise ValueError(
@@ -174,12 +168,14 @@ def normal_or_p_d_start(args):
         if args.batch_max_tokens is None:
             args.batch_max_tokens = args.max_req_total_len
         else:
-            assert args.batch_max_tokens >= args.max_req_total_len, "batch_max_tokens must >= max_req_total_len"
+            assert args.batch_max_tokens >= args.max_req_total_len, (
+                f"batch_max_tokens must >= max_req_total_len, but got {args.batch_max_tokens} and {args.max_req_total_len}")
     else:
         # chunked 模式下
         if args.batch_max_tokens is None:
-            args.batch_max_tokens = min(args.max_req_total_len, 2 * args.chunked_prefill_size + 256)
-
+            args.batch_max_tokens = 16384 // args.dp
+        if args.chunked_prefill_size is None:
+            args.chunked_prefill_size = args.batch_max_tokens // 2
         assert (
             args.batch_max_tokens >= args.chunked_prefill_size
         ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
@@ -203,9 +199,11 @@ def normal_or_p_d_start(args):
         args.data_type = get_dtype(args.model_dir)
     assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]
 
-    already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
-    if args.run_mode == "decode":
-        already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port]
+    already_uesd_ports = [args.port]
+    if args.nccl_port is not None:
+        already_uesd_ports.append(args.nccl_port)
+    if args.pd_decode_rpyc_port is not None:
+        already_uesd_ports.append(args.pd_decode_rpyc_port)
 
     # 提前锁定端口，防止在单个机器上启动多个实列的时候，要到模型启动的时候才能
     # 捕获到端口设置冲突的问题
@@ -214,10 +212,11 @@ def normal_or_p_d_start(args):
 
     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
     (
+        nccl_port,
         router_port,
         detokenization_port,
         http_server_port,
@@ -226,16 +225,24 @@
         cache_port,
         metric_port,
         multi_level_kv_cache_port,
-    ) = can_use_ports[0:8]
-    can_use_ports = can_use_ports[8:]
+        pd_decode_rpyc_port,
+    ) = can_use_ports[0:10]
+    can_use_ports = can_use_ports[10:]
 
     visual_model_tp_ports = []
+    visual_nccl_ports = []
     for _ in range(args.visual_dp):
         tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
-        can_use_ports = can_use_ports[args.visual_tp :]
         visual_model_tp_ports.append(tp_ports_for_dp)
+        can_use_ports = can_use_ports[args.visual_tp :]
+        visual_nccl_ports.append(can_use_ports[0])
+        can_use_ports = can_use_ports[1:]
 
     # 将申请好的端口放入args参数中
+    if args.nccl_port is None:
+        args.nccl_port = nccl_port
+    if args.pd_decode_rpyc_port is None:
+        args.pd_decode_rpyc_port = pd_decode_rpyc_port
     args.router_port = router_port
     args.detokenization_port = detokenization_port
     args.http_server_port = http_server_port
@@ -244,7 +251,7 @@
     args.cache_port = cache_port
     args.metric_port = metric_port
     args.multi_level_kv_cache_port = multi_level_kv_cache_port
-
+    args.visual_nccl_ports = visual_nccl_ports
     # 申请在 p d 分离模式下，会用的端口
     args.pd_node_infer_rpyc_ports = 
can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 239cebfdd..897c48e93 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -20,7 +20,7 @@ class StartArgs: pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) - pd_decode_rpyc_port: int = field(default=42000) + pd_decode_rpyc_port: int = field(default=None) select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) @@ -61,7 +61,7 @@ class StartArgs: node_rank: int = field(default=0) max_req_total_len: int = field(default=2048 + 1024) nccl_host: str = field(default="127.0.0.1") - nccl_port: int = field(default=28765) + nccl_port: int = field(default=None) use_config_server_to_init_nccl: bool = field(default=False) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) @@ -71,7 +71,7 @@ class StartArgs: router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - chunked_prefill_size: int = field(default=8192) + chunked_prefill_size: int = field(default=None) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) @@ -97,12 +97,12 @@ class StartArgs: job_name: str = field(default="lightllm") grouping_key: List[str] = field(default_factory=list) push_interval: int = field(default=10) - visual_infer_batch_size: int = field(default=1) + visual_infer_batch_size: int = field(default=None) visual_send_batch_size: int = field(default=1) visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) visual_tp: int = field(default=1) visual_dp: int = field(default=1) - visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) + visual_nccl_ports: List[int] = field(default=None) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) @@ -110,7 +110,7 @@ class StartArgs: graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) - graph_max_len_in_batch: int = field(default=8192) + graph_max_len_in_batch: int = field(default=0) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) vit_quant_type: Optional[str] = field(default=None) @@ -121,6 +121,9 @@ class StartArgs: llm_decode_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) + vit_att_backend: List[str] = field( + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + ) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index e69de29bb..6bc0923da 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -0,0 +1,16 @@ +from lightllm.common.basemodel.attention_vit.base_att import 
BaseVitAttBackend
+from lightllm.common.basemodel.attention_vit.create_utils import get_vit_att_backend_class
+
+VIT_ATTN_BACKEND: BaseVitAttBackend = None
+
+
+def set_vit_att_backend(backend_name: str):
+    global VIT_ATTN_BACKEND
+    VIT_ATTN_BACKEND = get_vit_att_backend_class(backend_name)
+    return
+
+
+def get_vit_attn_backend():
+    if VIT_ATTN_BACKEND is None:
+        raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call set_vit_att_backend() first.")
+    return VIT_ATTN_BACKEND._vit_att_fwd
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index a389272e5..202c2fc45 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -14,6 +14,7 @@
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from .model_infer.model_rpc import start_model_process, VisualModelRpcClient
+from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
@@ -63,7 +64,7 @@ def __init__(
 
     async def wait_to_model_ready(self):
         self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
-
+        self.vit_attn_backend = init_vit_att_backend(index=0)
         for dp_rank_id in range(self.vit_dp):
             tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
             for tp_rank_id in range(self.vit_tp):
@@ -91,6 +92,7 @@ async def wait_to_model_ready(self):
                     "quant_type": self.args.vit_quant_type,
                     "quant_cfg": self.args.vit_quant_cfg,
                     "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
+                    "vit_attn_backend": self.vit_attn_backend,
                 }
                 init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index d3d1610f3..8f4a1ee45 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -24,6 +24,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
+from lightllm.server.visualserver import set_vit_att_backend
 
 
 class VisualModelRpcServer(rpyc.Service):
@@ -42,7 +43,8 @@ def exposed_init_model(self, kvargs):
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]
-
+        self.vit_attn_backend = kvargs["vit_attn_backend"]
+        set_vit_att_backend(self.vit_attn_backend)
         init_vision_distributed_env(kvargs)
 
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py
index 1f02c3952..0e2f9c962 100644
--- a/lightllm/utils/backend_validator.py
+++ b/lightllm/utils/backend_validator.py
@@ -121,6 +121,72 @@ def _softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK: tl.constexpr):
     return True, None
 
 
+def _validate_xformers():
+    """Validate Xformers Attn with ground truth."""
+    try:
+        import torch
+
+        if not torch.cuda.is_available():
+            return False, "CUDA not available"
+
+        import xformers.ops as xformers_ops
+        from xformers.ops import fmha
+ except Exception as e: + return False, f"xformers import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 4, 8, 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + k = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + v = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + + expected = _compute_ground_truth(q, k, v, is_causal=False) + + q_bmhd = q.transpose(1, 2).contiguous() # (B, seq, heads, dim) + k_bmhd = k.transpose(1, 2).contiguous() + v_bmhd = v.transpose(1, 2).contiguous() + + try: + out = xformers_ops.memory_efficient_attention(q_bmhd, k_bmhd, v_bmhd, p=0.0) + except Exception as e: + return False, f"xformers attention run failed: {type(e).__name__}: {e}" + + out = out.transpose(1, 2).contiguous() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + + return True, None + + +def _validate_sdpa(): + """Validate SDPA Attn with ground truth.""" + try: + import torch + from torch.nn.functional import scaled_dot_product_attention + except Exception as e: + return False, f"SDPA import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 4, 8, 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + k = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + v = torch.randn(batch, heads, seq, dim, dtype=dtype, device=device) + + expected = _compute_ground_truth(q, k, v, is_causal=False) + + out = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False, f"Output mismatch: max diff {(out - expected).abs().max().item():.6f}" + + return True, None + + def _run_in_subprocess(backend_name, pipe): """Run validation in subprocess with suppressed output.""" import sys @@ -133,6 +199,10 @@ def _run_in_subprocess(backend_name, pipe): try: if backend_name == "fa3": success, err = _validate_fa3() + elif backend_name == "xformers": + success, err = _validate_xformers() + elif backend_name == "sdpa": + success, err = _validate_sdpa() elif backend_name == "flashinfer": success, err = _validate_flashinfer() elif backend_name == "triton": diff --git a/requirements.txt b/requirements.txt index 8d9a011be..a3b9473f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,6 +87,7 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 +xformers==0.0.32.post1 xxhash==3.6.0 torchvision==0.23.0 interegular==0.3.3 diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 7f1c2b493..919f379b9 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -43,6 +43,7 @@ def test_model_inference(args): "disable_cudagraph": args.disable_cudagraph, "llm_prefill_att_backend": args.llm_prefill_att_backend, "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, "llm_kv_type": args.llm_kv_type, "llm_kv_quant_group_size": args.llm_kv_quant_group_size, }
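
A minimal usage sketch of the ViT attention backend plumbing introduced above. It assumes a CUDA device and the packed [total_tokens, num_heads, head_dim] layout the ViT kernels operate on; the backend name, head count, head dim and sequence lengths below are illustrative and not values taken from this patch. Inside the server, the backend name is resolved by init_vit_att_backend() on the visual manager and applied in each worker via set_vit_att_backend().

import torch

from lightllm.server.visualserver import get_vit_attn_backend, set_vit_att_backend

# Pick a concrete backend by name; "auto" resolution happens in
# init_vit_att_backend(), which probes backends through the validators.
set_vit_att_backend("sdpa")

# Two images of 256 patches each, packed along the token axis
# (illustrative sizes, not model-derived values).
heads, head_dim = 16, 80
cu_seqlens = torch.tensor([0, 256, 512], dtype=torch.int32, device="cuda")
total_tokens = int(cu_seqlens[-1])
q = torch.randn(total_tokens, heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

# get_vit_attn_backend() returns the selected backend's _vit_att_fwd function;
# the attention output is written into o, mirroring the call sites in qwen2_visual.py.
get_vit_attn_backend()(q, k, v, o, cu_seqlens, 256)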