[MM][OOT] Support CPU seq_lens for OOT MMEncoderAttention kernels (#36605)
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
     hidden_size = num_heads * head_size
     tp_size = 1
 
-    sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths(
-        AttentionBackendEnum.FLASHINFER, cu_seqlens_np
-    )
-    sequence_lengths = torch.from_numpy(sequence_lengths_np).to(
-        device, dtype=torch.int32, non_blocking=True
+    sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+        AttentionBackendEnum.FLASHINFER,
+        cu_seqlens_np,
+        device,
     )
 
     max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer(
     )
     max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
 
-    cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens(
+    cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
         AttentionBackendEnum.FLASHINFER,
         cu_seqlens_np,
         hidden_size,
         tp_size,
-    )
-    cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
-        device, dtype=torch.int32, non_blocking=True
+        device,
     )
 
     scale = 1.0 / head_size**0.5
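The two test hunks above track the new contract: maybe_compute_seq_lens and maybe_recompute_cu_seqlens now take a device and return torch tensors directly, so the manual torch.from_numpy(...).to(...) hops in the test disappear. A minimal, self-contained sketch of the conversion the helper now performs internally for the FLASHINFER path (values are illustrative, and the add_padding_to_seqlens step is omitted since its definition is outside this diff):

import numpy as np
import torch

cu_seqlens_np = np.array([0, 5, 9], dtype=np.int32)
device = torch.device("cpu")  # the real test targets an accelerator device

# Per-sequence lengths from cumulative offsets, moved to the target device,
# mirroring what maybe_compute_seq_lens now returns.
seq_lens = torch.from_numpy(cu_seqlens_np[1:] - cu_seqlens_np[:-1]).to(
    device, non_blocking=True
)
print(seq_lens)  # tensor([5, 4], dtype=torch.int32)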
@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
 
 
+def get_oot_class_by_name(class_name: str) -> type | None:
+    if class_name in op_registry_oot:
+        return op_registry_oot[class_name]
+    return None
+
+
 class PluggableLayer(nn.Module):
     """
     Base class for pluggable layers.
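The new get_oot_class_by_name helper resolves an out-of-tree (OOT) override by the in-tree class's __name__. A minimal sketch of how an entry could land in op_registry_oot and be resolved; the MyPlatformMMEncoderAttention class and the direct dict write are illustrative assumptions, since platform plugins normally register through vLLM's registration machinery rather than writing to the dict:

from vllm.model_executor.custom_op import get_oot_class_by_name, op_registry_oot


class MyPlatformMMEncoderAttention:  # hypothetical OOT implementation
    """Placeholder for a platform-specific MMEncoderAttention override."""


# Keyed by the in-tree class name so cls.__name__ lookups find it.
op_registry_oot["MMEncoderAttention"] = MyPlatformMMEncoderAttention

assert get_oot_class_by_name("MMEncoderAttention") is MyPlatformMMEncoderAttention
assert get_oot_class_by_name("SomethingElse") is None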
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp):
         return max_seqlen
 
     @classmethod
-    def maybe_compute_sequence_lengths(
+    def maybe_compute_seq_lens(
         cls,
         attn_backend: AttentionBackendEnum,
         cu_seqlens: np.ndarray,
-    ) -> np.ndarray | None:
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device)  # type: ignore[attr-defined]
+
         if attn_backend != AttentionBackendEnum.FLASHINFER:
             return None
 
         sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         sequence_lengths = add_padding_to_seqlens(
             sequence_lengths, len(sequence_lengths), 0
         )
+        sequence_lengths = torch.from_numpy(sequence_lengths).to(
+            device, non_blocking=True
+        )
         return sequence_lengths
 
     @classmethod
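The early dispatch added above is what the PR title refers to: an OOT platform can intercept maybe_compute_seq_lens before the in-tree FLASHINFER path runs and, for example, keep seq_lens on CPU instead of copying them to the accelerator. A minimal sketch of such an override, assuming a hypothetical plugin class rather than any real vLLM plugin API:

import numpy as np
import torch


class CPUSeqLensMMEncoderAttention:  # hypothetical OOT override
    @classmethod
    def maybe_compute_seq_lens(cls, attn_backend, cu_seqlens, device):
        # Keep seq_lens on CPU: some OOT kernels consume host-side lengths
        # directly, so the .to(device) copy of the in-tree path is skipped.
        return torch.from_numpy(cu_seqlens[1:] - cu_seqlens[:-1])


cu = np.array([0, 3, 7, 12], dtype=np.int32)
out = CPUSeqLensMMEncoderAttention.maybe_compute_seq_lens(None, cu, torch.device("cpu"))
print(out, out.device)  # tensor([3, 4, 5], dtype=torch.int32) cpu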
@@ -139,24 +147,31 @@ class MMEncoderAttention(CustomOp):
         cu_seqlens: np.ndarray,
         hidden_size: int,
         tp_size: int,
-    ) -> np.ndarray:
-        if attn_backend != AttentionBackendEnum.FLASHINFER:
-            return cu_seqlens
+        device: torch.device,
+    ) -> torch.Tensor:
+        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+            return oot_class.maybe_recompute_cu_seqlens(  # type: ignore[attr-defined]
+                attn_backend, cu_seqlens, hidden_size, tp_size, device
+            )
 
-        batch_size = len(cu_seqlens) - 1
-        scale = hidden_size // tp_size
-        cu_seqlens = cu_seqlens * scale
+        if attn_backend == AttentionBackendEnum.FLASHINFER:
+            batch_size = len(cu_seqlens) - 1
+            scale = hidden_size // tp_size
+            cu_seqlens = cu_seqlens * scale
 
-        cu_seqlens_qko = cu_seqlens
-        cu_seqlens_v = cu_seqlens * 3
+            cu_seqlens_qko = cu_seqlens
+            cu_seqlens_v = cu_seqlens * 3
 
-        cu_seqlens_qko = add_padding_to_seqlens(
-            cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
-        )
-        cu_seqlens_v = add_padding_to_seqlens(
-            cu_seqlens_v, batch_size, cu_seqlens_v[-1]
-        )
-        return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+            cu_seqlens_qko = add_padding_to_seqlens(
+                cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
+            )
+            cu_seqlens_v = add_padding_to_seqlens(
+                cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+            )
+            cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
+
+        cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
+        return cu_seqlens
 
     def __init__(
         self,
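For the in-tree FLASHINFER path, maybe_recompute_cu_seqlens now scales the offsets by the per-rank hidden width, builds separate Q/K/output and V offset arrays, and returns a device tensor. A worked NumPy example of that arithmetic, with the add_padding_to_seqlens padding step omitted since its definition is outside this diff:

import numpy as np

cu_seqlens = np.array([0, 3, 7], dtype=np.int32)  # two sequences of lengths 3, 4
hidden_size, tp_size = 8, 2

scale = hidden_size // tp_size      # per-rank hidden width: 4
cu_seqlens = cu_seqlens * scale     # [0, 12, 28]
cu_seqlens_qko = cu_seqlens         # Q/K/output offsets
cu_seqlens_v = cu_seqlens * 3       # V offsets scaled by 3, as in the diff

print(np.concatenate([cu_seqlens_qko, cu_seqlens_v]))
# [ 0 12 28  0 36 84]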
@@ -983,13 +983,11 @@ class Qwen3Omni_VisionTransformer(nn.Module):
             grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
         ).cumsum(axis=0, dtype=np.int32)
         cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens_np
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend,
+            cu_seqlens_np,
+            self.device,
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
 
         hidden_states_list = []
         deepstack_visual_indexes = self.deepstack_visual_indexes
@@ -550,13 +550,9 @@ class Qwen3_VisionTransformer(nn.Module):
             axis=0, dtype=np.int32
         )
         cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
-        sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths(
-            self.attn_backend, cu_seqlens
+        sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend, cu_seqlens, self.device
         )
-        if sequence_lengths is not None:
-            sequence_lengths = torch.from_numpy(sequence_lengths).to(
-                self.device, non_blocking=True
-            )
         max_seqlen = torch.tensor(
             MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
             dtype=torch.int32,
@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
             cu_seqlens,
             self.hidden_size,
             self.tp_size,
+            self.device,
         )
-        cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
         hidden_states = hidden_states.unsqueeze(1)
 
         deepstack_feature_lists = []