diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index 3bcde3b0a..858d9504a 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
     hidden_size = num_heads * head_size
     tp_size = 1
 
-    sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
-        AttentionBackendEnum.FLASHINFER, cu_seqlens_np
-    )
-    sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
-        device, dtype=torch.int32, non_blocking=True
+    sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+        AttentionBackendEnum.FLASHINFER,
+        cu_seqlens_np,
+        device,
     )
 
     max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
@@ -309,14 +308,12 @@
     )
     max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
 
-    cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
+    cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
         AttentionBackendEnum.FLASHINFER,
         cu_seqlens_np,
         hidden_size,
         tp_size,
-    )
-    cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
-        device, dtype=torch.int32, non_blocking=True
+        device,
     )
 
     scale = 1.0 / head_size**0.5
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 851546297..b8e372e88 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 
 
+def get_oot_class_by_name(class_name: str) -> type | None:
+    if class_name in op_registry_oot:
+        return op_registry_oot[class_name]
+    return None
+
+
 class PluggableLayer(nn.Module):
     """
     Base class for pluggable layers.
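
Note for reviewers: `get_oot_class_by_name` is just a lookup into the existing `op_registry_oot` dict. A minimal sketch of the dispatch it enables follows; the class name and the direct dict assignment are illustrative only (real out-of-tree plugins register through the CustomOp machinery, not by writing into the registry by hand):

```python
from vllm.model_executor.custom_op import (
    get_oot_class_by_name,
    op_registry_oot,
)


class MyPlatformEncoderAttention:
    """Hypothetical out-of-tree override for illustration."""

    @classmethod
    def maybe_compute_seq_lens(cls, attn_backend, cu_seqlens, device):
        ...  # platform-specific implementation


# Illustrative registration: keyed by the in-tree class name, so the
# in-tree classmethods can find the override via cls.__name__.
op_registry_oot["MMEncoderAttention"] = MyPlatformEncoderAttention

assert get_oot_class_by_name("MMEncoderAttention") is MyPlatformEncoderAttention
assert get_oot_class_by_name("UnregisteredOp") is None
```
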
diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index d902f2ebc..bc0687ed2 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -119,17 +119,25 @@
         return max_seqlen
 
     @classmethod
-    def maybe_compute_sequence_lengths(
+    def maybe_compute_seq_lens(
         cls,
         attn_backend: AttentionBackendEnum,
         cu_seqlens: np.ndarray,
-    ) -> np.ndarray | None:
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device)  # type: ignore[attr-defined]
+
         if attn_backend != AttentionBackendEnum.FLASHINFER:
             return None
+
         sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         sequence_lengths = add_padding_to_seqlens(
             sequence_lengths, len(sequence_lengths), 0
         )
+        sequence_lengths = torch.from_numpy(sequence_lengths).to(
+            device, non_blocking=True
+        )
         return sequence_lengths
 
     @classmethod
@@ -139,24 +147,31 @@
         cu_seqlens: np.ndarray,
         hidden_size: int,
         tp_size: int,
-    ) -> np.ndarray:
-        if attn_backend != AttentionBackendEnum.FLASHINFER:
-            return cu_seqlens
+        device: torch.device,
+    ) -> torch.Tensor:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_recompute_cu_seqlens(  # type: ignore[attr-defined]
+                attn_backend, cu_seqlens, hidden_size, tp_size, device
+            )
 
-        batch_size = len(cu_seqlens) - 1
-        scale = hidden_size // tp_size
-        cu_seqlens = cu_seqlens * scale
+        if attn_backend == AttentionBackendEnum.FLASHINFER:
+            batch_size = len(cu_seqlens) - 1
+            scale = hidden_size // tp_size
+            cu_seqlens = cu_seqlens * scale
 
-        cu_seqlens_qko = cu_seqlens
-        cu_seqlens_v = cu_seqlens * 3
+            cu_seqlens_qko = cu_seqlens
+            cu_seqlens_v = cu_seqlens * 3
 
-        cu_seqlens_qko = add_padding_to_seqlens(
-            cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
-        )
-        cu_seqlens_v = add_padding_to_seqlens(
-            cu_seqlens_v, batch_size, cu_seqlens_v[-1]
-        )
-        return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+            cu_seqlens_qko = add_padding_to_seqlens(
+                cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
+            )
+            cu_seqlens_v = add_padding_to_seqlens(
+                cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+            )
+            cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+
+        cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
+        return cu_seqlens
 
     def __init__(
         self,
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index f3a8d8d53..ff352a735 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
             grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
         ).cumsum(axis=0, dtype=np.int32)
         cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens_np
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend,
+            cu_seqlens_np,
+            self.device,
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
 
         hidden_states_list = []
         deepstack_visual_indexes = self.deepstack_visual_indexes
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index dcfa087c1..dc0842258 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -550,13 +550,9 @@
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend, cu_seqlens, self.device
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
         max_seqlen = torch.tensor(
             MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
             dtype=torch.int32,
@@ -567,8 +563,8 @@
             cu_seqlens,
             self.hidden_size,
             self.tp_size,
+            self.device,
         )
-        cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
 
         hidden_states = hidden_states.unsqueeze(1)
         deepstack_feature_lists = []
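
Taken together, the refactor moves the numpy-to-torch transfer (and the FLASHINFER `None` check) out of every caller and into the `MMEncoderAttention` helpers themselves. A rough sketch of the new caller-side pattern, assuming the FLASHINFER backend; the sequence lengths, `hidden_size`, and `tp_size` values are made up, and the `AttentionBackendEnum` import path is assumed rather than taken from this patch:

```python
import numpy as np
import torch

# Import path assumed; the test file in this patch imports the same names.
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)

# Illustrative inputs: two sequences of 16 and 32 tokens.
cu_seqlens_np = np.array([0, 16, 48], dtype=np.int32)
device = torch.device("cuda")

# Before: the helpers returned np.ndarray (or None), and each caller did
# its own torch.from_numpy(...).to(device, non_blocking=True) plus a None
# check. After: the helpers take `device` and return a device tensor
# directly, or None when the backend needs no sequence lengths.
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
    AttentionBackendEnum.FLASHINFER, cu_seqlens_np, device
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
    AttentionBackendEnum.FLASHINFER,
    cu_seqlens_np,
    1280,  # hidden_size (illustrative)
    1,     # tp_size (illustrative)
    device,
)
```
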