[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Shanshan Shen authored 2026-03-12 18:28:23 +08:00, committed by GitHub
commit f0d3658c0f (parent 57431d8231)
5 changed files with 52 additions and 40 deletions
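
In brief: the torch.from_numpy(...).to(device) boilerplate at each call site moves into the MMEncoderAttention classmethods, which now take the target device and return ready-to-use tensors, and both classmethods gain an out-of-tree (OOT) dispatch hook so platform plugins can decide where these tensors live (e.g. keep seq_lens on CPU, per the title). A sketch of the new call pattern, using only names that appear in the diffs below (imports omitted, since file paths are collapsed in this view):

# Illustrative only: callers now pass the device and receive tensors back;
# no numpy-to-torch copy remains at the call site.
seq_lens = MMEncoderAttention.maybe_compute_seq_lens(
    AttentionBackendEnum.FLASHINFER, cu_seqlens_np, device
)  # -> torch.Tensor | None (None for in-tree non-FLASHINFER backends)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
    AttentionBackendEnum.FLASHINFER, cu_seqlens_np, hidden_size, tp_size, device
)  # -> torch.Tensor on `device`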


@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
     hidden_size = num_heads * head_size
     tp_size = 1
-    sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
-        AttentionBackendEnum.FLASHINFER, cu_seqlens_np
-    )
-    sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
-        device, dtype=torch.int32, non_blocking=True
+    sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+        AttentionBackendEnum.FLASHINFER,
+        cu_seqlens_np,
+        device,
     )
     max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer(
     )
     max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
-    cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
+    cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
         AttentionBackendEnum.FLASHINFER,
         cu_seqlens_np,
         hidden_size,
         tp_size,
-    )
-    cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
-        device, dtype=torch.int32, non_blocking=True
+        device,
     )
     scale = 1.0 / head_size**0.5


@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 
 
+def get_oot_class_by_name(class_name: str) -> type | None:
+    if class_name in op_registry_oot:
+        return op_registry_oot[class_name]
+    return None
+
+
 class PluggableLayer(nn.Module):
     """
     Base class for pluggable layers.
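
The lookup added above is a plain probe of the OOT registry by class name. A minimal sketch of how a platform plugin would be resolved through it, assuming the plugin writes into op_registry_oot (the registration path itself is outside this diff; the plugin class name is hypothetical):

from vllm.model_executor.custom_op import get_oot_class_by_name, op_registry_oot

# Hypothetical OOT plugin class; real plugins register CustomOp subclasses.
class MyPlatformMMEncoderAttention:
    """Stand-in for a platform's MMEncoderAttention implementation."""

# Keying the registry by the in-tree class name is what makes the
# get_oot_class_by_name(cls.__name__) probe in MMEncoderAttention hit.
op_registry_oot["MMEncoderAttention"] = MyPlatformMMEncoderAttention

assert get_oot_class_by_name("MMEncoderAttention") is MyPlatformMMEncoderAttention
assert get_oot_class_by_name("NotRegistered") is None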


@@ -6,7 +6,7 @@ import numpy as np
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp):
         return max_seqlen
 
     @classmethod
-    def maybe_compute_sequence_lengths(
+    def maybe_compute_seq_lens(
         cls,
         attn_backend: AttentionBackendEnum,
         cu_seqlens: np.ndarray,
-    ) -> np.ndarray | None:
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device)  # type: ignore[attr-defined]
+
         if attn_backend != AttentionBackendEnum.FLASHINFER:
             return None
+
         sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         sequence_lengths = add_padding_to_seqlens(
             sequence_lengths, len(sequence_lengths), 0
         )
+        sequence_lengths = torch.from_numpy(sequence_lengths).to(
+            device, non_blocking=True
+        )
         return sequence_lengths
 
     @classmethod
@@ -139,24 +147,31 @@
         cu_seqlens: np.ndarray,
         hidden_size: int,
         tp_size: int,
-    ) -> np.ndarray:
-        if attn_backend != AttentionBackendEnum.FLASHINFER:
-            return cu_seqlens
+        device: torch.device,
+    ) -> torch.Tensor:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_recompute_cu_seqlens(  # type: ignore[attr-defined]
+                attn_backend, cu_seqlens, hidden_size, tp_size, device
+            )
 
-        batch_size = len(cu_seqlens) - 1
-        scale = hidden_size // tp_size
-        cu_seqlens = cu_seqlens * scale
+        if attn_backend == AttentionBackendEnum.FLASHINFER:
+            batch_size = len(cu_seqlens) - 1
+            scale = hidden_size // tp_size
+            cu_seqlens = cu_seqlens * scale
 
-        cu_seqlens_qko = cu_seqlens
-        cu_seqlens_v = cu_seqlens * 3
-        cu_seqlens_qko = add_padding_to_seqlens(
-            cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
-        )
-        cu_seqlens_v = add_padding_to_seqlens(
-            cu_seqlens_v, batch_size, cu_seqlens_v[-1]
-        )
-        return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+            cu_seqlens_qko = cu_seqlens
+            cu_seqlens_v = cu_seqlens * 3
+            cu_seqlens_qko = add_padding_to_seqlens(
+                cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
+            )
+            cu_seqlens_v = add_padding_to_seqlens(
+                cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+            )
+            cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+        cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
+        return cu_seqlens
 
     def __init__(
         self,
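
Both classmethods above now defer to a registered OOT class before doing any in-tree work, which is what the commit title enables: a platform whose attention kernels read sequence lengths from host memory can return CPU tensors and ignore the device hint. A hedged sketch of such an override (class name hypothetical; a real plugin would subclass MMEncoderAttention, omitted here to keep the snippet self-contained — the dispatch only needs the methods to exist, hence the type: ignore[attr-defined] above):

import torch

from vllm.model_executor.custom_op import op_registry_oot


class CPUSeqLensMMEncoderAttention:  # hypothetical OOT override
    @classmethod
    def maybe_compute_seq_lens(cls, attn_backend, cu_seqlens, device):
        # This platform's kernels consume seq_lens from host memory, so the
        # tensor is built on CPU and the requested device is ignored.
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        return torch.from_numpy(seq_lens)

    @classmethod
    def maybe_recompute_cu_seqlens(
        cls, attn_backend, cu_seqlens, hidden_size, tp_size, device
    ):
        # No FLASHINFER-style repacking on this platform; plain CPU tensor.
        return torch.from_numpy(cu_seqlens)


op_registry_oot["MMEncoderAttention"] = CPUSeqLensMMEncoderAttention

With such a class registered, the Qwen3 call sites below pick up the OOT behavior transparently: they call the classmethods and no longer do their own from_numpy/to(device) copies.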


@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
             grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
         ).cumsum(axis=0, dtype=np.int32)
         cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens_np
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend,
+            cu_seqlens_np,
+            self.device,
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
 
         hidden_states_list = []
         deepstack_visual_indexes = self.deepstack_visual_indexes


@@ -550,13 +550,9 @@ class Qwen3_VisionTransformer(nn.Module):
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend, cu_seqlens, self.device
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
         max_seqlen = torch.tensor(
             MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
             dtype=torch.int32,
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
             cu_seqlens,
             self.hidden_size,
             self.tp_size,
+            self.device,
         )
-        cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
 
         hidden_states = hidden_states.unsqueeze(1)
         deepstack_feature_lists = []