[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Shanshan Shen
2026-03-12 18:28:23 +08:00
committed by GitHub
parent 57431d8231
commit f0d3658c0f
5 changed files with 52 additions and 40 deletions

View File

@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
hidden_size = num_heads * head_size
tp_size = 1
sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
AttentionBackendEnum.FLASHINFER, cu_seqlens_np
)
sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
device, dtype=torch.int32, non_blocking=True
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
AttentionBackendEnum.FLASHINFER,
cu_seqlens_np,
device,
)
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer(
)
max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
AttentionBackendEnum.FLASHINFER,
cu_seqlens_np,
hidden_size,
tp_size,
)
cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
device, dtype=torch.int32, non_blocking=True
device,
)
scale = 1.0 / head_size**0.5

View File

@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
def get_oot_class_by_name(class_name: str) -> type | None:
if class_name in op_registry_oot:
return op_registry_oot[class_name]
return None
class PluggableLayer(nn.Module):
"""
Base class for pluggable layers.

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp):
return max_seqlen
@classmethod
def maybe_compute_sequence_lengths(
def maybe_compute_seq_lens(
cls,
attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray,
) -> np.ndarray | None:
device: torch.device,
) -> torch.Tensor | None:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
if attn_backend != AttentionBackendEnum.FLASHINFER:
return None
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_lengths = add_padding_to_seqlens(
sequence_lengths, len(sequence_lengths), 0
)
sequence_lengths = torch.from_numpy(sequence_lengths).to(
device, non_blocking=True
)
return sequence_lengths
@classmethod
@@ -139,24 +147,31 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: np.ndarray,
hidden_size: int,
tp_size: int,
) -> np.ndarray:
if attn_backend != AttentionBackendEnum.FLASHINFER:
return cu_seqlens
device: torch.device,
) -> torch.Tensor:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
attn_backend, cu_seqlens, hidden_size, tp_size, device
)
batch_size = len(cu_seqlens) - 1
scale = hidden_size // tp_size
cu_seqlens = cu_seqlens * scale
if attn_backend == AttentionBackendEnum.FLASHINFER:
batch_size = len(cu_seqlens) - 1
scale = hidden_size // tp_size
cu_seqlens = cu_seqlens * scale
cu_seqlens_qko = cu_seqlens
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_qko = cu_seqlens
cu_seqlens_v = cu_seqlens * 3
cu_seqlens_qko = add_padding_to_seqlens(
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
)
cu_seqlens_v = add_padding_to_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
cu_seqlens_qko = add_padding_to_seqlens(
cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
)
cu_seqlens_v = add_padding_to_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
)
cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
return cu_seqlens
def __init__(
self,

View File

@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens_np
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend,
cu_seqlens_np,
self.device,
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
hidden_states_list = []
deepstack_visual_indexes = self.deepstack_visual_indexes

View File

@@ -550,13 +550,9 @@ class Qwen3_VisionTransformer(nn.Module):
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
self.attn_backend, cu_seqlens
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, self.device
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens,
self.hidden_size,
self.tp_size,
self.device,
)
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = []