[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
|
||||
hidden_size = num_heads * head_size
|
||||
tp_size = 1
|
||||
|
||||
sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
AttentionBackendEnum.FLASHINFER, cu_seqlens_np
|
||||
)
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
|
||||
device, dtype=torch.int32, non_blocking=True
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
cu_seqlens_np,
|
||||
device,
|
||||
)
|
||||
|
||||
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
|
||||
@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer(
|
||||
)
|
||||
max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
|
||||
|
||||
cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
cu_seqlens_np,
|
||||
hidden_size,
|
||||
tp_size,
|
||||
)
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
|
||||
device, dtype=torch.int32, non_blocking=True
|
||||
device,
|
||||
)
|
||||
|
||||
scale = 1.0 / head_size**0.5
|
||||
|
||||
@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
|
||||
|
||||
|
||||
def get_oot_class_by_name(class_name: str) -> type | None:
|
||||
if class_name in op_registry_oot:
|
||||
return op_registry_oot[class_name]
|
||||
return None
|
||||
|
||||
|
||||
class PluggableLayer(nn.Module):
|
||||
"""
|
||||
Base class for pluggable layers.
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp):
|
||||
return max_seqlen
|
||||
|
||||
@classmethod
|
||||
def maybe_compute_sequence_lengths(
|
||||
def maybe_compute_seq_lens(
|
||||
cls,
|
||||
attn_backend: AttentionBackendEnum,
|
||||
cu_seqlens: np.ndarray,
|
||||
) -> np.ndarray | None:
|
||||
device: torch.device,
|
||||
) -> torch.Tensor | None:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
|
||||
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return None
|
||||
|
||||
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
sequence_lengths = add_padding_to_seqlens(
|
||||
sequence_lengths, len(sequence_lengths), 0
|
||||
)
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
device, non_blocking=True
|
||||
)
|
||||
return sequence_lengths
|
||||
|
||||
@classmethod
|
||||
@@ -139,10 +147,14 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens: np.ndarray,
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
) -> np.ndarray:
|
||||
if attn_backend != AttentionBackendEnum.FLASHINFER:
|
||||
return cu_seqlens
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
|
||||
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
|
||||
attn_backend, cu_seqlens, hidden_size, tp_size, device
|
||||
)
|
||||
|
||||
if attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
batch_size = len(cu_seqlens) - 1
|
||||
scale = hidden_size // tp_size
|
||||
cu_seqlens = cu_seqlens * scale
|
||||
@@ -156,7 +168,10 @@ class MMEncoderAttention(CustomOp):
|
||||
cu_seqlens_v = add_padding_to_seqlens(
|
||||
cu_seqlens_v, batch_size, cu_seqlens_v[-1]
|
||||
)
|
||||
return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
|
||||
cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
|
||||
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
|
||||
return cu_seqlens
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -983,12 +983,10 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
|
||||
).cumsum(axis=0, dtype=np.int32)
|
||||
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
self.attn_backend, cu_seqlens_np
|
||||
)
|
||||
if sequence_lengths is not None:
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
self.device, non_blocking=True
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend,
|
||||
cu_seqlens_np,
|
||||
self.device,
|
||||
)
|
||||
|
||||
hidden_states_list = []
|
||||
|
||||
@@ -550,12 +550,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
axis=0, dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
|
||||
self.attn_backend, cu_seqlens
|
||||
)
|
||||
if sequence_lengths is not None:
|
||||
sequence_lengths = torch.from_numpy(sequence_lengths).to(
|
||||
self.device, non_blocking=True
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, self.device
|
||||
)
|
||||
max_seqlen = torch.tensor(
|
||||
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
|
||||
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
cu_seqlens,
|
||||
self.hidden_size,
|
||||
self.tp_size,
|
||||
self.device,
|
||||
)
|
||||
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
|
||||
Reference in New Issue
Block a user