[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -334,6 +334,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
block_size,
|
||||
use_mla=True,
|
||||
use_sparse=use_sparse,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -129,6 +129,7 @@ class CpuPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
|
||||
@@ -45,17 +45,29 @@ torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
device_capability: DeviceCapability,
|
||||
num_heads: int | None = None,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
|
||||
if num_heads is not None and num_heads <= 16:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
]
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
*sparse_backends,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
@@ -182,6 +194,8 @@ class CudaPlatformBase(Platform):
|
||||
use_flashmla = False
|
||||
use_cutlass_mla = False
|
||||
use_flashinfer_mla = False
|
||||
use_flashmla_sparse = False
|
||||
use_flashinfer_mla_sparse = False
|
||||
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
@@ -217,6 +231,10 @@ class CudaPlatformBase(Platform):
|
||||
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
|
||||
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
|
||||
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
|
||||
use_flashinfer_mla_sparse = (
|
||||
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
|
||||
)
|
||||
|
||||
if (
|
||||
use_flashmla
|
||||
@@ -242,12 +260,24 @@ class CudaPlatformBase(Platform):
|
||||
"Forcing kv cache block size to 64 for FlashInferMLA backend."
|
||||
)
|
||||
|
||||
# TODO(Chen): remove this hacky code
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
if use_sparse:
|
||||
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
|
||||
use_flashmla_sparse = True
|
||||
|
||||
if use_flashmla_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
|
||||
32,
|
||||
64,
|
||||
):
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashInferMLASparse "
|
||||
"backend."
|
||||
)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
# Note: model_config may be None during testing
|
||||
@@ -276,6 +306,7 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
device_capability: DeviceCapability,
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
@@ -284,7 +315,9 @@ class CudaPlatformBase(Platform):
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla, device_capability
|
||||
attn_selector_config.use_mla,
|
||||
device_capability,
|
||||
num_heads,
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
@@ -307,6 +340,7 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
@@ -336,6 +370,7 @@ class CudaPlatformBase(Platform):
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
device_capability=device_capability,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
reasons_str = (
|
||||
"{"
|
||||
|
||||
@@ -233,6 +233,7 @@ class Platform:
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
"""Get the attention backend class of a device."""
|
||||
return ""
|
||||
|
||||
@@ -265,6 +265,7 @@ class RocmPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class XPUPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
|
||||
353
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Normal file
353
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""FlashInfer MLA Sparse Attention Backend.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
|
||||
For sparse MLA:
|
||||
- block_tables shape changes from [batch_size, max_num_blocks] (dense)
|
||||
to [batch_size, q_len_per_request, sparse_mla_top_k] (sparse)
|
||||
- The sparse indices represent physical cache slot positions to attend to
|
||||
- sparse_mla_top_k parameter must be set to the topk value
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLASparseBackend(AttentionBackend):
|
||||
"""FlashInfer MLA backend with sparse attention support.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
"""
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLASparseImpl"]:
|
||||
return FlashInferMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLASparseMetadataBuilder"]:
|
||||
return FlashInferMLASparseMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
# FlashInfer sparse MLA targets Blackwell (SM 10.x)
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if qk_nope_head_dim != 128:
|
||||
return (
|
||||
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
|
||||
f"but got {qk_nope_head_dim}"
|
||||
)
|
||||
# Check for index_topk which indicates sparse model
|
||||
if not hasattr(hf_text_config, "index_topk"):
|
||||
return "FlashInfer MLA Sparse requires model with index_topk config"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMLASparseMetadata(AttentionMetadata):
|
||||
"""Attention metadata for FlashInfer MLA Sparse backend."""
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
num_actual_tokens: int
|
||||
|
||||
# Query start locations
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
|
||||
# Sequence lengths for all requests (context + query)
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# Sparse-specific
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
|
||||
class FlashInferMLASparseMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashInferMLASparseMetadata]
|
||||
):
|
||||
"""Builder for FlashInfer MLA Sparse attention metadata."""
|
||||
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.layer_names = layer_names
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> FlashInferMLASparseMetadata:
|
||||
cm = common_attn_metadata
|
||||
num_tokens = cm.num_actual_tokens
|
||||
|
||||
# Build req_id_per_token mapping
|
||||
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
|
||||
)
|
||||
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
|
||||
torch.from_numpy(req_id_per_token), non_blocking=True
|
||||
)
|
||||
req_id_per_token_tensor = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
return FlashInferMLASparseMetadata(
|
||||
num_reqs=cm.num_reqs,
|
||||
max_query_len=cm.max_query_len,
|
||||
max_seq_len=cm.max_seq_len,
|
||||
num_actual_tokens=cm.num_actual_tokens,
|
||||
query_start_loc=cm.query_start_loc,
|
||||
slot_mapping=cm.slot_mapping,
|
||||
block_table=cm.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token_tensor,
|
||||
seq_lens=cm.seq_lens,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
)
|
||||
|
||||
|
||||
# Global workspace buffer (lazily initialized)
|
||||
_fi_sparse_workspace: torch.Tensor | None = None
|
||||
|
||||
|
||||
def _get_workspace_buffer(device: torch.device) -> torch.Tensor:
|
||||
global _fi_sparse_workspace
|
||||
if _fi_sparse_workspace is None:
|
||||
_fi_sparse_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
return _fi_sparse_workspace
|
||||
|
||||
|
||||
class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]):
|
||||
"""FlashInfer MLA Sparse implementation.
|
||||
|
||||
Uses the TRT-LLM MLA kernel with sparse_mla_top_k parameter for
|
||||
sparse attention computation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLASparseImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap"
|
||||
)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLASparseImpl"
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
# MLA-specific dimensions
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"]
|
||||
self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"]
|
||||
|
||||
assert indexer is not None, "Indexer required for sparse MLA"
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
|
||||
self._workspace_buffer: torch.Tensor | None = None
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMLASparseMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
topk_indices_physical, seq_lens = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token[:num_actual_toks],
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=topk_indices.shape[1],
|
||||
return_valid_counts=True,
|
||||
)
|
||||
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = _get_workspace_buffer(q.device)
|
||||
|
||||
if self.bmm1_scale is None:
|
||||
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q.unsqueeze(1),
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=topk_indices_physical.unsqueeze(1),
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=attn_metadata.topk_tokens,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
sparse_mla_top_k=attn_metadata.topk_tokens,
|
||||
)
|
||||
return o.view(-1, o.shape[-2], o.shape[-1]), None
|
||||
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -26,6 +25,9 @@ from vllm.v1.attention.backend import (
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
@@ -203,166 +205,6 @@ class FlashMLASparseMetadata(AttentionMetadata):
|
||||
fp8_use_mixed_batch: bool = False
|
||||
|
||||
|
||||
# Kernel with prefill workspace support
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
|
||||
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
HAS_PREFILL: tl.constexpr,
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
is_prefill = False
|
||||
if HAS_PREFILL:
|
||||
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
|
||||
is_prefill = prefill_req_id >= 0
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
is_invalid_tok |= ~valid_block
|
||||
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
|
||||
out_val = base * BLOCK_SIZE + inblock_off
|
||||
|
||||
# Override with prefill output if prefill is enabled
|
||||
if HAS_PREFILL:
|
||||
workspace_start = tl.load(
|
||||
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
|
||||
)
|
||||
prefill_out = workspace_start + tok
|
||||
out_val = tl.where(is_prefill, prefill_out, out_val)
|
||||
out_val = tl.where(is_invalid_tok, -1, out_val)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
|
||||
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
|
||||
instead of global cache slots. prefill_workspace_request_ids and
|
||||
prefill_workspace_starts must be provided.
|
||||
|
||||
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
|
||||
prefill request index (maps to prefill_workspace_starts)
|
||||
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
|
||||
starts for each prefill request
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None
|
||||
assert prefill_workspace_starts is not None
|
||||
assert prefill_workspace_request_ids.dtype == torch.int32
|
||||
assert prefill_workspace_starts.dtype == torch.int32
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
max_num_blocks_per_req = block_table.shape[1]
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Prepare prefill pointers
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None # for mypy
|
||||
assert prefill_workspace_starts is not None # for mypy
|
||||
assert prefill_workspace_request_ids.is_contiguous()
|
||||
assert prefill_workspace_starts.is_contiguous()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
prefill_workspace_request_ids,
|
||||
prefill_workspace_starts,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
HAS_PREFILL_WORKSPACE,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def get_prefill_workspace_size(max_model_len: int):
|
||||
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
|
||||
191
vllm/v1/attention/backends/mla/sparse_utils.py
Normal file
191
vllm/v1/attention/backends/mla/sparse_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for sparse MLA backends."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Kernel with prefill workspace support and valid count tracking
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
valid_count_ptr, # int32 [num_tokens] - output valid count per row
|
||||
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
|
||||
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
HAS_PREFILL: tl.constexpr,
|
||||
COUNT_VALID: tl.constexpr, # whether to count valid indices
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
is_prefill = False
|
||||
if HAS_PREFILL:
|
||||
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
|
||||
is_prefill = prefill_req_id >= 0
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
is_invalid_tok |= ~valid_block
|
||||
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
|
||||
out_val = base * BLOCK_SIZE + inblock_off
|
||||
|
||||
# Override with prefill output if prefill is enabled
|
||||
if HAS_PREFILL:
|
||||
workspace_start = tl.load(
|
||||
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
|
||||
)
|
||||
prefill_out = workspace_start + tok
|
||||
out_val = tl.where(is_prefill, prefill_out, out_val)
|
||||
out_val = tl.where(is_invalid_tok, -1, out_val)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
# Count valid indices in this tile and atomically add to row total
|
||||
if COUNT_VALID:
|
||||
tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
|
||||
tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
return_valid_counts: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
|
||||
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
|
||||
instead of global cache slots. prefill_workspace_request_ids and
|
||||
prefill_workspace_starts must be provided.
|
||||
|
||||
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
|
||||
prefill request index (maps to prefill_workspace_starts)
|
||||
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
|
||||
starts for each prefill request
|
||||
|
||||
When return_valid_counts is True, also returns the count of valid (non -1)
|
||||
indices per row, computed during the same kernel pass (no extra overhead).
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None
|
||||
assert prefill_workspace_starts is not None
|
||||
assert prefill_workspace_request_ids.dtype == torch.int32
|
||||
assert prefill_workspace_starts.dtype == torch.int32
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
max_num_blocks_per_req = block_table.shape[1]
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Allocate valid count buffer if needed (must be zero-initialized for atomics)
|
||||
valid_counts: torch.Tensor | None = None
|
||||
if return_valid_counts:
|
||||
valid_counts = torch.zeros(
|
||||
num_tokens, dtype=torch.int32, device=token_indices.device
|
||||
)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Prepare prefill pointers
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None # for mypy
|
||||
assert prefill_workspace_starts is not None # for mypy
|
||||
assert prefill_workspace_request_ids.is_contiguous()
|
||||
assert prefill_workspace_starts.is_contiguous()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
valid_counts,
|
||||
prefill_workspace_request_ids,
|
||||
prefill_workspace_starts,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
HAS_PREFILL_WORKSPACE,
|
||||
return_valid_counts,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
|
||||
if return_valid_counts:
|
||||
assert valid_counts is not None
|
||||
return out, valid_counts
|
||||
return out
|
||||
@@ -62,6 +62,10 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
FLASHINFER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
|
||||
"FlashInferMLASparseBackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
|
||||
@@ -53,6 +53,7 @@ def get_attn_backend(
|
||||
use_sparse: bool = False,
|
||||
use_mm_prefix: bool = False,
|
||||
attn_type: str | None = None,
|
||||
num_heads: int | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
@@ -66,7 +67,6 @@ def get_attn_backend(
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
backend_enum = vllm_config.attention_config.backend
|
||||
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=head_size,
|
||||
@@ -81,8 +81,9 @@ def get_attn_backend(
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
backend=backend_enum,
|
||||
backend=vllm_config.attention_config.backend,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,12 +91,14 @@ def get_attn_backend(
|
||||
def _cached_get_attn_backend(
|
||||
backend,
|
||||
attn_selector_config: AttentionSelectorConfig,
|
||||
num_heads: int | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user