[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from common import (
|
||||
ModelParameterSweep,
|
||||
ParameterSweep,
|
||||
ResultsFormatter,
|
||||
batch_spec_sort_key,
|
||||
is_mla_backend,
|
||||
)
|
||||
|
||||
@@ -218,10 +219,13 @@ def run_model_parameter_sweep(
|
||||
by_param_and_spec[key].append(r)
|
||||
break
|
||||
|
||||
# Sort by param value then spec
|
||||
# Sort by param value then spec (batch_size, q_len, kv_len)
|
||||
sorted_keys = sorted(
|
||||
by_param_and_spec.keys(),
|
||||
key=lambda x: (int(x[0]) if x[0].isdigit() else x[0], x[1]),
|
||||
key=lambda x: (
|
||||
int(x[0]) if x[0].isdigit() else x[0],
|
||||
batch_spec_sort_key(x[1]),
|
||||
),
|
||||
)
|
||||
|
||||
current_param_value = None
|
||||
@@ -330,7 +334,7 @@ def run_parameter_sweep(
|
||||
by_spec[spec] = []
|
||||
by_spec[spec].append(r)
|
||||
|
||||
for spec in sorted(by_spec.keys()):
|
||||
for spec in sorted(by_spec.keys(), key=batch_spec_sort_key):
|
||||
results = by_spec[spec]
|
||||
best = min(results, key=lambda r: r.mean_time)
|
||||
console.print(
|
||||
@@ -496,15 +500,18 @@ def main():
|
||||
if "description" in yaml_config:
|
||||
console.print(f"[dim]{yaml_config['description']}[/]")
|
||||
|
||||
# Override args with YAML values
|
||||
# (YAML takes precedence unless CLI arg was explicitly set)
|
||||
# Backend(s)
|
||||
if "backend" in yaml_config:
|
||||
args.backend = yaml_config["backend"]
|
||||
args.backends = None
|
||||
elif "backends" in yaml_config:
|
||||
args.backends = yaml_config["backends"]
|
||||
args.backend = None
|
||||
# Override args with YAML values, but CLI args take precedence
|
||||
# Check if CLI provided backends (they would be non-None and not default)
|
||||
cli_backends_provided = args.backends is not None or args.backend is not None
|
||||
|
||||
# Backend(s) - only use YAML if CLI didn't specify
|
||||
if not cli_backends_provided:
|
||||
if "backend" in yaml_config:
|
||||
args.backend = yaml_config["backend"]
|
||||
args.backends = None
|
||||
elif "backends" in yaml_config:
|
||||
args.backends = yaml_config["backends"]
|
||||
args.backend = None
|
||||
|
||||
# Check for special modes
|
||||
if "mode" in yaml_config:
|
||||
@@ -544,13 +551,15 @@ def main():
|
||||
args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads)
|
||||
args.block_size = model.get("block_size", args.block_size)
|
||||
|
||||
# Benchmark settings
|
||||
if "benchmark" in yaml_config:
|
||||
bench = yaml_config["benchmark"]
|
||||
args.device = bench.get("device", args.device)
|
||||
args.repeats = bench.get("repeats", args.repeats)
|
||||
args.warmup_iters = bench.get("warmup_iters", args.warmup_iters)
|
||||
args.profile_memory = bench.get("profile_memory", args.profile_memory)
|
||||
# Benchmark settings (top-level keys)
|
||||
if "device" in yaml_config:
|
||||
args.device = yaml_config["device"]
|
||||
if "repeats" in yaml_config:
|
||||
args.repeats = yaml_config["repeats"]
|
||||
if "warmup_iters" in yaml_config:
|
||||
args.warmup_iters = yaml_config["warmup_iters"]
|
||||
if "profile_memory" in yaml_config:
|
||||
args.profile_memory = yaml_config["profile_memory"]
|
||||
|
||||
# Parameter sweep configuration
|
||||
if "parameter_sweep" in yaml_config:
|
||||
|
||||
@@ -16,13 +16,32 @@ from batch_spec import get_batch_type, parse_batch_spec
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def batch_spec_sort_key(spec: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
Extract sorting key from batch spec: (batch_size, max_q_len, max_kv_len).
|
||||
|
||||
This ensures results are sorted by batch size first, then query length,
|
||||
then sequence length, rather than alphabetically.
|
||||
"""
|
||||
try:
|
||||
requests = parse_batch_spec(spec)
|
||||
batch_size = len(requests)
|
||||
max_q_len = max(r.q_len for r in requests) if requests else 0
|
||||
max_kv_len = max(r.kv_len for r in requests) if requests else 0
|
||||
return (batch_size, max_q_len, max_kv_len)
|
||||
except Exception:
|
||||
# Fallback for unparseable specs
|
||||
return (0, 0, 0)
|
||||
|
||||
|
||||
# Mock classes for vLLM attention infrastructure
|
||||
|
||||
|
||||
class MockHfConfig:
|
||||
"""Mock HuggingFace config that satisfies vLLM's requirements."""
|
||||
|
||||
def __init__(self, mla_dims: dict):
|
||||
def __init__(self, mla_dims: dict, index_topk: int | None = None):
|
||||
self.num_attention_heads = mla_dims["num_q_heads"]
|
||||
self.num_key_value_heads = mla_dims["num_kv_heads"]
|
||||
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
|
||||
@@ -33,6 +52,8 @@ class MockHfConfig:
|
||||
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
|
||||
self.v_head_dim = mla_dims["v_head_dim"]
|
||||
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
|
||||
if index_topk is not None:
|
||||
self.index_topk = index_topk
|
||||
|
||||
def get_text_config(self):
|
||||
return self
|
||||
@@ -83,6 +104,38 @@ class MockKVBProj:
|
||||
return (result,) # Return as tuple to match ColumnParallelLinear API
|
||||
|
||||
|
||||
class MockIndexer:
|
||||
"""Mock Indexer for sparse MLA backends.
|
||||
|
||||
Provides topk_indices_buffer that sparse MLA backends use to determine
|
||||
which KV cache slots to attend to for each token.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
topk_tokens: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.topk_tokens = topk_tokens
|
||||
self.topk_indices_buffer = torch.zeros(
|
||||
(max_num_tokens, topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def fill_random_indices(self, num_tokens: int, max_kv_len: int):
|
||||
"""Fill topk_indices_buffer with random valid indices for benchmarking."""
|
||||
indices = torch.randint(
|
||||
0,
|
||||
max_kv_len,
|
||||
(num_tokens, self.topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.topk_indices_buffer.device,
|
||||
)
|
||||
self.topk_indices_buffer[:num_tokens] = indices
|
||||
|
||||
|
||||
class MockLayer(AttentionLayerBase):
|
||||
"""Mock attention layer with scale parameters and impl.
|
||||
|
||||
@@ -327,6 +380,9 @@ class ResultsFormatter:
|
||||
specs_order.append(spec)
|
||||
by_spec[spec][r.config.backend] = r
|
||||
|
||||
# Sort specs by (batch_size, q_len, kv_len) instead of alphabetically
|
||||
specs_order = sorted(by_spec.keys(), key=batch_spec_sort_key)
|
||||
|
||||
# Create shortened backend names for display
|
||||
def shorten_backend_name(name: str) -> str:
|
||||
"""Shorten long backend names for table display."""
|
||||
@@ -493,10 +549,11 @@ def get_attention_scale(head_dim: int) -> float:
|
||||
|
||||
def is_mla_backend(backend: str) -> bool:
|
||||
"""
|
||||
Check if backend is an MLA backend using the backend's is_mla() property.
|
||||
Check if backend is an MLA backend using the AttentionBackendEnum.
|
||||
|
||||
Args:
|
||||
backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")
|
||||
backend: Backend name matching AttentionBackendEnum exactly
|
||||
(e.g., "FLASHMLA_SPARSE")
|
||||
|
||||
Returns:
|
||||
True if the backend is an MLA backend, False otherwise
|
||||
@@ -504,7 +561,8 @@ def is_mla_backend(backend: str) -> bool:
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
try:
|
||||
backend_class = AttentionBackendEnum[backend.upper()].get_class()
|
||||
backend_enum = AttentionBackendEnum[backend]
|
||||
backend_class = backend_enum.get_class()
|
||||
return backend_class.is_mla()
|
||||
except (KeyError, ValueError, ImportError):
|
||||
except (KeyError, ValueError, ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
num_layers: 60
|
||||
num_q_heads: 128
|
||||
num_q_heads: 128 # Base value, can be swept for TP simulation
|
||||
num_kv_heads: 1 # MLA uses single latent KV
|
||||
head_dim: 576
|
||||
kv_lora_rank: 512
|
||||
@@ -12,6 +12,13 @@ model:
|
||||
v_head_dim: 128
|
||||
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
|
||||
batch_specs:
|
||||
# Small batches, varying sequence lengths
|
||||
- "16q1s512" # 16 requests, 512 KV cache
|
||||
@@ -34,28 +41,30 @@ batch_specs:
|
||||
# Very large batches
|
||||
- "128q1s1k" # 128 requests, 1k KV cache
|
||||
- "128q1s2k" # 128 requests, 2k KV cache
|
||||
- "128q1s4k" # 128 requests, 4k KV cache
|
||||
- "128q1s8k" # 128 requests, 8k KV cache
|
||||
|
||||
# Long context
|
||||
- "32q1s16k" # 32 requests, 16k KV cache
|
||||
- "32q1s32k" # 32 requests, 32k KV cache
|
||||
|
||||
backends:
|
||||
- cutlass_mla
|
||||
- flashinfer_mla
|
||||
- flashattn_mla # Hopper only
|
||||
- flashmla # Hopper only
|
||||
- CUTLASS_MLA
|
||||
- FLASHINFER_MLA
|
||||
- FLASH_ATTN_MLA # Hopper only
|
||||
- FLASHMLA # Hopper only
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 5
|
||||
warmup_iters: 3
|
||||
repeats: 100
|
||||
warmup_iters: 10
|
||||
profile_memory: true
|
||||
|
||||
# Backend-specific tuning
|
||||
cutlass_mla:
|
||||
CUTLASS_MLA:
|
||||
num_kv_splits: auto # or specific value like 4, 8, 16
|
||||
|
||||
flashattn_mla:
|
||||
FLASH_ATTN_MLA:
|
||||
reorder_batch_threshold: 512
|
||||
|
||||
flashmla:
|
||||
FLASHMLA:
|
||||
reorder_batch_threshold: 1
|
||||
|
||||
@@ -45,10 +45,10 @@ batch_specs:
|
||||
- "4q4k_60q1s4k" # 4 prefill + 60 decode
|
||||
|
||||
backends:
|
||||
- cutlass_mla
|
||||
- flashinfer_mla
|
||||
- flashattn_mla # Hopper only
|
||||
- flashmla # Hopper only
|
||||
- CUTLASS_MLA
|
||||
- FLASHINFER_MLA
|
||||
- FLASH_ATTN_MLA # Hopper only
|
||||
- FLASHMLA # Hopper only
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 5
|
||||
|
||||
62
benchmarks/attention_benchmarks/configs/mla_prefill.yaml
Normal file
62
benchmarks/attention_benchmarks/configs/mla_prefill.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
# MLA prefill-only benchmark configuration for sparse backends
|
||||
|
||||
model:
|
||||
name: "deepseek-v3"
|
||||
num_layers: 60
|
||||
num_q_heads: 128
|
||||
num_kv_heads: 1
|
||||
head_dim: 576
|
||||
kv_lora_rank: 512
|
||||
qk_nope_head_dim: 128
|
||||
qk_rope_head_dim: 64
|
||||
v_head_dim: 128
|
||||
block_size: 128
|
||||
|
||||
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
|
||||
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
|
||||
model_parameter_sweep:
|
||||
param_name: "num_q_heads"
|
||||
values: [128, 64, 32, 16]
|
||||
label_format: "{backend}_{value}h"
|
||||
|
||||
batch_specs:
|
||||
# Pure prefill
|
||||
- "1q512"
|
||||
- "1q1k"
|
||||
- "1q2k"
|
||||
- "1q4k"
|
||||
- "1q8k"
|
||||
|
||||
# Batched pure prefill
|
||||
- "2q512"
|
||||
- "2q1k"
|
||||
- "2q2k"
|
||||
- "2q4k"
|
||||
- "2q8k"
|
||||
- "4q512"
|
||||
- "4q1k"
|
||||
- "4q2k"
|
||||
- "4q4k"
|
||||
- "4q8k"
|
||||
- "8q512"
|
||||
- "8q1k"
|
||||
- "8q2k"
|
||||
- "8q4k"
|
||||
- "8q8k"
|
||||
|
||||
# Extend
|
||||
- "1q512s4k"
|
||||
- "1q512s8k"
|
||||
- "1q1ks8k"
|
||||
- "1q2ks8k"
|
||||
- "1q2ks16k"
|
||||
- "1q4ks16k"
|
||||
|
||||
backends:
|
||||
- FLASHMLA_SPARSE
|
||||
- FLASHINFER_MLA_SPARSE
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 10
|
||||
warmup_iters: 3
|
||||
profile_memory: true
|
||||
@@ -6,7 +6,7 @@
|
||||
description: "Decode vs Prefill pipeline crossover analysis"
|
||||
|
||||
# Test FlashAttn MLA
|
||||
backend: flashattn_mla
|
||||
backend: FLASH_ATTN_MLA
|
||||
|
||||
# Mode: decode_vs_prefill comparison (special sweep mode)
|
||||
# For each batch spec, we'll test both decode and prefill pipelines
|
||||
@@ -62,11 +62,10 @@ model:
|
||||
block_size: 128
|
||||
|
||||
# Benchmark settings
|
||||
benchmark:
|
||||
device: "cuda:0"
|
||||
repeats: 15 # More repeats for spec decode variance
|
||||
warmup_iters: 5
|
||||
profile_memory: false
|
||||
device: "cuda:0"
|
||||
repeats: 15 # More repeats for spec decode variance
|
||||
warmup_iters: 5
|
||||
profile_memory: false
|
||||
|
||||
# Output
|
||||
output:
|
||||
|
||||
@@ -41,18 +41,17 @@ batch_specs:
|
||||
|
||||
# Backends that support query length > 1
|
||||
backends:
|
||||
- flashattn_mla # reorder_batch_threshold = 512
|
||||
- flashmla # reorder_batch_threshold = 1 (tunable)
|
||||
- FLASH_ATTN_MLA # reorder_batch_threshold = 512
|
||||
- FLASHMLA # reorder_batch_threshold = 1 (tunable)
|
||||
|
||||
# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism
|
||||
# - flashinfer_mla
|
||||
# - FLASHINFER_MLA
|
||||
|
||||
# Benchmark settings
|
||||
benchmark:
|
||||
device: "cuda:0"
|
||||
repeats: 10 # More repeats for statistical significance
|
||||
warmup_iters: 5
|
||||
profile_memory: false
|
||||
device: "cuda:0"
|
||||
repeats: 10 # More repeats for statistical significance
|
||||
warmup_iters: 5
|
||||
profile_memory: false
|
||||
|
||||
# Test these threshold values for optimization
|
||||
parameter_sweep:
|
||||
|
||||
@@ -36,11 +36,11 @@ batch_specs:
|
||||
- "q1ks2k" # 1k query, 2k sequence
|
||||
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
|
||||
|
||||
# Available backends: flash, triton, flashinfer
|
||||
# Available backends: FLASH_ATTN, TRITON_ATTN, FLASHINFER
|
||||
backends:
|
||||
- flash
|
||||
- triton
|
||||
- flashinfer
|
||||
- FLASH_ATTN
|
||||
- TRITON_ATTN
|
||||
- FLASHINFER
|
||||
|
||||
device: "cuda:0"
|
||||
repeats: 5
|
||||
|
||||
@@ -8,14 +8,13 @@ This module provides helpers for running MLA backends without
|
||||
needing full VllmConfig integration.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from batch_spec import parse_batch_spec
|
||||
from common import (
|
||||
BenchmarkResult,
|
||||
MockHfConfig,
|
||||
MockIndexer,
|
||||
MockKVBProj,
|
||||
MockLayer,
|
||||
setup_mla_dims,
|
||||
@@ -62,6 +61,7 @@ def create_minimal_vllm_config(
|
||||
block_size: int = 128,
|
||||
max_num_seqs: int = 256,
|
||||
mla_dims: dict | None = None,
|
||||
index_topk: int | None = None,
|
||||
) -> VllmConfig:
|
||||
"""
|
||||
Create minimal VllmConfig for MLA benchmarks.
|
||||
@@ -73,6 +73,8 @@ def create_minimal_vllm_config(
|
||||
max_num_seqs: Maximum number of sequences
|
||||
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
|
||||
setup_mla_dims(model_name)
|
||||
index_topk: Optional topk value for sparse MLA backends. If provided,
|
||||
the config will include index_topk for sparse attention.
|
||||
|
||||
Returns:
|
||||
VllmConfig for benchmarking
|
||||
@@ -82,7 +84,7 @@ def create_minimal_vllm_config(
|
||||
mla_dims = setup_mla_dims(model_name)
|
||||
|
||||
# Create mock HF config first (avoids downloading from HuggingFace)
|
||||
mock_hf_config = MockHfConfig(mla_dims)
|
||||
mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk)
|
||||
|
||||
# Create a temporary minimal config.json to avoid HF downloads
|
||||
# This ensures consistent ModelConfig construction without network access
|
||||
@@ -120,16 +122,12 @@ def create_minimal_vllm_config(
|
||||
seed=0,
|
||||
max_model_len=32768,
|
||||
quantization=None,
|
||||
quantization_param_path=None,
|
||||
enforce_eager=False,
|
||||
max_context_len_to_capture=None,
|
||||
max_seq_len_to_capture=8192,
|
||||
max_logprobs=20,
|
||||
disable_sliding_window=False,
|
||||
skip_tokenizer_init=True,
|
||||
served_model_name=None,
|
||||
limit_mm_per_prompt=None,
|
||||
use_async_output_proc=True,
|
||||
config_format="auto",
|
||||
)
|
||||
finally:
|
||||
@@ -180,56 +178,65 @@ def create_minimal_vllm_config(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Backend name to class name prefix mapping
|
||||
_BACKEND_NAME_MAP = {
|
||||
"flashattn_mla": "FlashAttnMLA",
|
||||
"flashmla": "FlashMLA",
|
||||
"flashinfer_mla": "FlashInferMLA",
|
||||
"cutlass_mla": "CutlassMLA",
|
||||
}
|
||||
|
||||
# Special properties that differ from defaults
|
||||
# Backend-specific properties that can't be inferred from the backend class
|
||||
# Keys are AttentionBackendEnum names (uppercase)
|
||||
_BACKEND_PROPERTIES = {
|
||||
"flashmla": {
|
||||
"FLASHMLA": {
|
||||
"query_format": "concat", # Single concatenated tensor (vs tuple)
|
||||
"block_size": 64, # FlashMLA uses fixed block size
|
||||
},
|
||||
"flashinfer_mla": {
|
||||
"block_size": 64, # FlashInfer MLA only supports 32 or 64
|
||||
"FLASHMLA_SPARSE": {
|
||||
"query_format": "concat", # Single concatenated tensor (vs tuple)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_backend_config(backend: str) -> dict:
|
||||
"""
|
||||
Get backend configuration using naming conventions.
|
||||
Get backend configuration from AttentionBackendEnum.
|
||||
|
||||
All MLA backends follow the pattern:
|
||||
- Module: vllm.v1.attention.backends.mla.{backend}
|
||||
- Impl: {Name}Impl
|
||||
- Metadata: {Name}Metadata (or MLACommonMetadata)
|
||||
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
|
||||
- MetadataBuilder: {Name}MetadataBuilder
|
||||
Uses the registry to get the backend class and extract configuration
|
||||
from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.).
|
||||
|
||||
Args:
|
||||
backend: Backend name matching AttentionBackendEnum exactly
|
||||
(e.g., "FLASHMLA_SPARSE")
|
||||
|
||||
Returns:
|
||||
Dict with backend configuration
|
||||
"""
|
||||
if backend not in _BACKEND_NAME_MAP:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
name = _BACKEND_NAME_MAP[backend]
|
||||
try:
|
||||
backend_enum = AttentionBackendEnum[backend]
|
||||
backend_class = backend_enum.get_class()
|
||||
except (KeyError, ValueError) as e:
|
||||
valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"]
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend}. "
|
||||
f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}"
|
||||
) from e
|
||||
|
||||
# Get block size from backend class
|
||||
block_sizes = backend_class.get_supported_kernel_block_sizes()
|
||||
# Use first supported block size (backends typically support one for MLA)
|
||||
block_size = block_sizes[0] if block_sizes else None
|
||||
if hasattr(block_size, "value"):
|
||||
# Handle MultipleOf enum
|
||||
block_size = None
|
||||
|
||||
# Check if sparse via class method if available
|
||||
is_sparse = getattr(backend_class, "is_sparse", lambda: False)()
|
||||
|
||||
# Get properties that can't be inferred
|
||||
props = _BACKEND_PROPERTIES.get(backend, {})
|
||||
|
||||
# Check if backend uses common metadata (FlashInfer, CUTLASS)
|
||||
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
|
||||
|
||||
return {
|
||||
"module": f"vllm.v1.attention.backends.mla.{backend}",
|
||||
"impl_class": f"{name}Impl",
|
||||
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
|
||||
"decode_metadata_class": "MLACommonDecodeMetadata"
|
||||
if uses_common
|
||||
else f"{name}DecodeMetadata",
|
||||
"builder_class": f"{name}MetadataBuilder",
|
||||
"backend_class": backend_class,
|
||||
"impl_class": backend_class.get_impl_cls(),
|
||||
"builder_class": backend_class.get_builder_cls(),
|
||||
"query_format": props.get("query_format", "tuple"),
|
||||
"block_size": props.get("block_size", None),
|
||||
"block_size": block_size,
|
||||
"is_sparse": is_sparse,
|
||||
}
|
||||
|
||||
|
||||
@@ -447,22 +454,26 @@ def _create_backend_impl(
|
||||
mla_dims: dict,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
max_num_tokens: int = 8192,
|
||||
index_topk: int | None = None,
|
||||
):
|
||||
"""
|
||||
Create backend implementation instance.
|
||||
|
||||
Args:
|
||||
backend_cfg: Backend configuration dict
|
||||
backend_cfg: Backend configuration dict from _get_backend_config()
|
||||
mla_dims: MLA dimension configuration
|
||||
vllm_config: VllmConfig instance
|
||||
device: Target device
|
||||
max_num_tokens: Maximum number of tokens for sparse indexer buffer
|
||||
index_topk: Topk value for sparse MLA backends
|
||||
|
||||
Returns:
|
||||
Tuple of (impl, layer, builder_instance)
|
||||
Tuple of (impl, layer, builder_instance, indexer)
|
||||
"""
|
||||
# Import backend classes
|
||||
backend_module = importlib.import_module(backend_cfg["module"])
|
||||
impl_class = getattr(backend_module, backend_cfg["impl_class"])
|
||||
# Get classes from backend config (already resolved by _get_backend_config)
|
||||
impl_class = backend_cfg["impl_class"]
|
||||
builder_class = backend_cfg["builder_class"]
|
||||
|
||||
# Calculate scale
|
||||
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
|
||||
@@ -474,26 +485,44 @@ def _create_backend_impl(
|
||||
v_head_dim=mla_dims["v_head_dim"],
|
||||
)
|
||||
|
||||
# Create indexer for sparse backends
|
||||
indexer = None
|
||||
if backend_cfg.get("is_sparse", False):
|
||||
if index_topk is None:
|
||||
index_topk = 2048 # Default topk for sparse MLA
|
||||
indexer = MockIndexer(
|
||||
max_num_tokens=max_num_tokens,
|
||||
topk_tokens=index_topk,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Build impl kwargs
|
||||
impl_kwargs = {
|
||||
"num_heads": mla_dims["num_q_heads"],
|
||||
"head_size": mla_dims["head_dim"],
|
||||
"scale": scale,
|
||||
"num_kv_heads": mla_dims["num_kv_heads"],
|
||||
"alibi_slopes": None,
|
||||
"sliding_window": None,
|
||||
"kv_cache_dtype": "auto",
|
||||
"logits_soft_cap": None,
|
||||
"attn_type": "decoder",
|
||||
"kv_sharing_target_layer_name": None,
|
||||
"q_lora_rank": None,
|
||||
"kv_lora_rank": mla_dims["kv_lora_rank"],
|
||||
"qk_nope_head_dim": mla_dims["qk_nope_head_dim"],
|
||||
"qk_rope_head_dim": mla_dims["qk_rope_head_dim"],
|
||||
"qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
|
||||
"v_head_dim": mla_dims["v_head_dim"],
|
||||
"kv_b_proj": mock_kv_b_proj,
|
||||
}
|
||||
|
||||
# Add indexer for sparse backends
|
||||
if indexer is not None:
|
||||
impl_kwargs["indexer"] = indexer
|
||||
|
||||
# Create impl
|
||||
impl = impl_class(
|
||||
num_heads=mla_dims["num_q_heads"],
|
||||
head_size=mla_dims["head_dim"],
|
||||
scale=scale,
|
||||
num_kv_heads=mla_dims["num_kv_heads"],
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=mla_dims["kv_lora_rank"],
|
||||
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
|
||||
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
|
||||
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
|
||||
v_head_dim=mla_dims["v_head_dim"],
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
)
|
||||
impl = impl_class(**impl_kwargs)
|
||||
|
||||
# Initialize DCP attributes
|
||||
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
|
||||
@@ -515,9 +544,7 @@ def _create_backend_impl(
|
||||
|
||||
# Create builder instance if needed
|
||||
builder_instance = None
|
||||
if backend_cfg["builder_class"]:
|
||||
builder_class = getattr(backend_module, backend_cfg["builder_class"])
|
||||
|
||||
if builder_class:
|
||||
# Populate static_forward_context so builder can find the layer
|
||||
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
|
||||
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
|
||||
@@ -529,7 +556,7 @@ def _create_backend_impl(
|
||||
device=device,
|
||||
)
|
||||
|
||||
return impl, layer, builder_instance
|
||||
return impl, layer, builder_instance, indexer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -594,6 +621,7 @@ def _run_single_benchmark(
|
||||
backend_cfg: dict,
|
||||
mla_dims: dict,
|
||||
device: torch.device,
|
||||
indexer=None,
|
||||
) -> BenchmarkResult:
|
||||
"""
|
||||
Run a single benchmark iteration.
|
||||
@@ -606,6 +634,7 @@ def _run_single_benchmark(
|
||||
backend_cfg: Backend configuration dict
|
||||
mla_dims: MLA dimension configuration
|
||||
device: Target device
|
||||
indexer: Optional MockIndexer for sparse backends
|
||||
|
||||
Returns:
|
||||
BenchmarkResult with timing statistics
|
||||
@@ -613,7 +642,9 @@ def _run_single_benchmark(
|
||||
# Parse batch spec
|
||||
requests = parse_batch_spec(config.batch_spec)
|
||||
q_lens = [r.q_len for r in requests]
|
||||
kv_lens = [r.kv_len for r in requests]
|
||||
total_q = sum(q_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
|
||||
# Determine block size
|
||||
block_size = backend_cfg["block_size"] or config.block_size
|
||||
@@ -641,8 +672,16 @@ def _run_single_benchmark(
|
||||
torch.bfloat16,
|
||||
)
|
||||
|
||||
# Determine which forward method to use based on metadata
|
||||
if metadata.decode is not None:
|
||||
# Fill indexer with random indices for sparse backends
|
||||
is_sparse = backend_cfg.get("is_sparse", False)
|
||||
if is_sparse and indexer is not None:
|
||||
indexer.fill_random_indices(total_q, max_kv_len)
|
||||
|
||||
# Determine which forward method to use
|
||||
if is_sparse:
|
||||
# Sparse backends use forward_mqa
|
||||
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
|
||||
elif metadata.decode is not None:
|
||||
forward_fn = lambda: impl._forward_decode(
|
||||
decode_inputs, kv_cache, metadata, layer
|
||||
)
|
||||
@@ -693,11 +732,13 @@ def _run_single_benchmark(
|
||||
def _run_mla_benchmark_batched(
|
||||
backend: str,
|
||||
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
|
||||
index_topk: int = 2048,
|
||||
) -> list[BenchmarkResult]:
|
||||
"""
|
||||
Unified batched MLA benchmark runner for all backends.
|
||||
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
|
||||
flashinfer_mla_sparse, flashmla_sparse
|
||||
|
||||
This function reuses backend initialization across multiple benchmarks
|
||||
to avoid setup/teardown overhead.
|
||||
@@ -707,6 +748,7 @@ def _run_mla_benchmark_batched(
|
||||
configs_with_params: List of (config, threshold, num_splits) tuples
|
||||
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
|
||||
- num_splits: num_kv_splits (CUTLASS only)
|
||||
index_topk: Topk value for sparse MLA backends (default 2048)
|
||||
|
||||
Returns:
|
||||
List of BenchmarkResult objects
|
||||
@@ -730,19 +772,27 @@ def _run_mla_benchmark_batched(
|
||||
if mla_dims is None:
|
||||
mla_dims = setup_mla_dims("deepseek-v3")
|
||||
|
||||
# Determine if this is a sparse backend
|
||||
is_sparse = backend_cfg.get("is_sparse", False)
|
||||
|
||||
# Create and set vLLM config for MLA (reused across all benchmarks)
|
||||
vllm_config = create_minimal_vllm_config(
|
||||
model_name="deepseek-v3", # Used only for model path
|
||||
block_size=block_size,
|
||||
mla_dims=mla_dims, # Use custom dims from config or default
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Create backend impl, layer, and builder (reused across benchmarks)
|
||||
impl, layer, builder_instance = _create_backend_impl(
|
||||
backend_cfg, mla_dims, vllm_config, device
|
||||
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
|
||||
impl, layer, builder_instance, indexer = _create_backend_impl(
|
||||
backend_cfg,
|
||||
mla_dims,
|
||||
vllm_config,
|
||||
device,
|
||||
index_topk=index_topk if is_sparse else None,
|
||||
)
|
||||
|
||||
# Run each benchmark with the shared impl
|
||||
@@ -768,6 +818,7 @@ def _run_mla_benchmark_batched(
|
||||
backend_cfg,
|
||||
mla_dims,
|
||||
device,
|
||||
indexer=indexer,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
@@ -793,20 +844,24 @@ def run_mla_benchmark(
|
||||
config,
|
||||
reorder_batch_threshold: int | None = None,
|
||||
num_kv_splits: int | None = None,
|
||||
index_topk: int = 2048,
|
||||
) -> BenchmarkResult | list[BenchmarkResult]:
|
||||
"""
|
||||
Unified MLA benchmark runner for all backends.
|
||||
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
|
||||
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
|
||||
flashinfer_mla_sparse, flashmla_sparse
|
||||
|
||||
Always uses batched execution internally for optimal performance.
|
||||
|
||||
Args:
|
||||
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
|
||||
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
|
||||
flashinfer_mla_sparse, flashmla_sparse)
|
||||
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
|
||||
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
|
||||
(single config mode only)
|
||||
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
|
||||
index_topk: Topk value for sparse MLA backends (default 2048)
|
||||
|
||||
Returns:
|
||||
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
|
||||
@@ -816,9 +871,9 @@ def run_mla_benchmark(
|
||||
# Already in batched format
|
||||
if len(config) > 0 and isinstance(config[0], tuple):
|
||||
# Format: [(cfg, param), ...] where param is threshold or num_splits
|
||||
if backend in ("flashattn_mla", "flashmla"):
|
||||
if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"):
|
||||
configs_with_params = [(cfg, param, None) for cfg, param in config]
|
||||
else: # cutlass_mla or flashinfer_mla
|
||||
else: # cutlass_mla, flashinfer_mla, or sparse backends
|
||||
configs_with_params = [(cfg, None, param) for cfg, param in config]
|
||||
else:
|
||||
# Format: [cfg, ...] - just configs
|
||||
@@ -830,7 +885,7 @@ def run_mla_benchmark(
|
||||
return_single = True
|
||||
|
||||
# Use unified batched execution
|
||||
results = _run_mla_benchmark_batched(backend, configs_with_params)
|
||||
results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk)
|
||||
|
||||
# Return single result or list based on input
|
||||
return results[0] if return_single else results
|
||||
|
||||
@@ -40,29 +40,29 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_BACKEND_CONFIG = {
|
||||
"flash": {
|
||||
"module": "vllm.v1.attention.backends.flash_attn",
|
||||
"backend_class": "FlashAttentionBackend",
|
||||
},
|
||||
"triton": {
|
||||
"module": "vllm.v1.attention.backends.triton_attn",
|
||||
"backend_class": "TritonAttentionBackend",
|
||||
},
|
||||
"flashinfer": {
|
||||
"module": "vllm.v1.attention.backends.flashinfer",
|
||||
"backend_class": "FlashInferBackend",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_backend_config(backend: str) -> dict:
|
||||
if backend not in _BACKEND_CONFIG:
|
||||
"""
|
||||
Get backend configuration from AttentionBackendEnum.
|
||||
|
||||
Args:
|
||||
backend: Backend name matching AttentionBackendEnum exactly
|
||||
(e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER")
|
||||
|
||||
Returns:
|
||||
Dict with backend_class
|
||||
"""
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
try:
|
||||
backend_enum = AttentionBackendEnum[backend]
|
||||
backend_class = backend_enum.get_class()
|
||||
except (KeyError, ValueError) as e:
|
||||
valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"]
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend}. "
|
||||
f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
|
||||
)
|
||||
return _BACKEND_CONFIG[backend]
|
||||
f"Unknown backend: {backend}. Valid backends: {valid_backends}"
|
||||
) from e
|
||||
|
||||
return {"backend_class": backend_class}
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -205,10 +205,7 @@ def _create_backend_impl(
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Create backend implementation instance."""
|
||||
import importlib
|
||||
|
||||
backend_module = importlib.import_module(backend_cfg["module"])
|
||||
backend_class = getattr(backend_module, backend_cfg["backend_class"])
|
||||
backend_class = backend_cfg["backend_class"]
|
||||
|
||||
scale = get_attention_scale(config.head_dim)
|
||||
|
||||
@@ -247,7 +244,7 @@ def _create_metadata_builder(
|
||||
|
||||
# Flashinfer needs get_per_layer_parameters mocked since we don't have
|
||||
# real model layers registered
|
||||
if backend_name == "flashinfer":
|
||||
if backend_name == "FLASHINFER":
|
||||
import unittest.mock
|
||||
|
||||
from vllm.v1.attention.backends.utils import PerLayerParameters
|
||||
@@ -438,7 +435,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
"""
|
||||
Run standard attention benchmark with real kernels.
|
||||
|
||||
Supports: flash, triton, flashinfer
|
||||
Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER
|
||||
|
||||
Args:
|
||||
config: Benchmark configuration
|
||||
@@ -453,7 +450,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
|
||||
|
||||
requests = parse_batch_spec(config.batch_spec)
|
||||
|
||||
if config.backend == "flashinfer":
|
||||
if config.backend == "FLASHINFER":
|
||||
requests = reorder_for_flashinfer(requests)
|
||||
|
||||
q_lens = [r.q_len for r in requests]
|
||||
|
||||
@@ -128,6 +128,7 @@ Priority is **1 = highest** (tried first).
|
||||
| 4 | `FLASHMLA` |
|
||||
| 5 | `TRITON_MLA` |
|
||||
| 6 | `FLASHMLA_SPARSE` |
|
||||
| 7 | `FLASHINFER_MLA_SPARSE` |
|
||||
|
||||
**Ampere/Hopper (SM 8.x-9.x):**
|
||||
|
||||
@@ -204,6 +205,7 @@ configuration.
|
||||
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
|
||||
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
||||
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
||||
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for the FlashMLA sparse backend utilities."""
|
||||
"""Unit tests for the sparse MLA backends and utilities."""
|
||||
|
||||
import math
|
||||
from types import MethodType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -25,6 +24,9 @@ from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashinfer_mla_sparse import (
|
||||
FlashInferMLASparseBackend,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
@@ -156,32 +158,48 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
return dequant_kv_c, dequant_k_pe
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
@pytest.mark.parametrize(
|
||||
"backend_cls",
|
||||
[FlashMLASparseBackend, FlashInferMLASparseBackend],
|
||||
ids=["FlashMLA", "FlashInfer"],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.parametrize("block_size", [32, 64])
|
||||
def test_sparse_backend_decode_correctness(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
backend_cls,
|
||||
batch_name,
|
||||
kv_cache_dtype,
|
||||
tensor_parallel_size,
|
||||
block_size,
|
||||
workspace_init,
|
||||
):
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
|
||||
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
|
||||
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
|
||||
if block_size not in supported_block_sizes:
|
||||
pytest.skip(
|
||||
f"{backend_cls.get_name()} does not support block_size={block_size}"
|
||||
)
|
||||
|
||||
if backend_cls == FlashMLASparseBackend:
|
||||
ok, reason = flashmla.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
elif backend_cls == FlashInferMLASparseBackend:
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher")
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
|
||||
# Model hyper-parameters (kept intentionally small for the unit test)
|
||||
total_num_heads = 128
|
||||
# Compute per-rank heads for simulated TP
|
||||
@@ -192,11 +210,10 @@ def test_sparse_backend_decode_correctness(
|
||||
qk_rope_head_dim = 64
|
||||
v_head_dim = 128
|
||||
head_size = kv_lora_rank + qk_rope_head_dim
|
||||
topk_tokens = 2048
|
||||
topk_tokens = 128
|
||||
|
||||
max_seqlen = max(batch_spec.seq_lens)
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
# Note: We use TP=1 to avoid multi-GPU requirements in CI.
|
||||
# The test simulates head partitioning via mocked methods below.
|
||||
@@ -247,11 +264,55 @@ def test_sparse_backend_decode_correctness(
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
|
||||
# Pre-compute positions and sparse indices for all tokens.
|
||||
# We need these BEFORE computing the reference to use sparse attention masks.
|
||||
total_query_tokens = sum(query_lens)
|
||||
positions = []
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
for q_idx in range(q_len):
|
||||
positions.append(ctx_len + q_idx)
|
||||
|
||||
# Create sparse indices with UNIQUE per-token offsets to catch bugs where
|
||||
# the kernel uses wrong indices for some tokens (e.g., due to incorrect
|
||||
# tensor shapes like [1, num_tokens, ...] instead of [num_tokens, 1, ...]).
|
||||
# Also include -1 masked indices to verify the kernel handles them correctly.
|
||||
sparse_indices = torch.empty(
|
||||
total_query_tokens, topk_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
for tok_idx in range(total_query_tokens):
|
||||
max_valid_idx = positions[tok_idx]
|
||||
offset = tok_idx * 7 # Prime number for varied offsets
|
||||
# Use only half the topk indices as valid, mask the rest with -1
|
||||
# This tests that the kernel correctly ignores -1 indices
|
||||
num_valid = min(topk_tokens // 2, max_valid_idx + 1)
|
||||
if num_valid > 0:
|
||||
valid_range = torch.arange(num_valid, device=device, dtype=torch.int32)
|
||||
tok_indices = (valid_range + offset) % (max_valid_idx + 1)
|
||||
# Pad with -1 for the remaining positions
|
||||
tok_indices = torch.cat(
|
||||
[
|
||||
tok_indices,
|
||||
torch.full(
|
||||
(topk_tokens - num_valid,), -1, device=device, dtype=torch.int32
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
tok_indices = torch.full(
|
||||
(topk_tokens,), -1, device=device, dtype=torch.int32
|
||||
)
|
||||
tok_indices[0] = 0 # At least one valid index
|
||||
sparse_indices[tok_idx] = tok_indices
|
||||
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
global_token_idx = 0
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
@@ -268,40 +329,53 @@ def test_sparse_backend_decode_correctness(
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
|
||||
# which truncates scales to powers of 2. Simulate this in reference.
|
||||
is_sm100 = torch.cuda.get_device_capability()[0] >= 10
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
scale=kv_cache_scale,
|
||||
simulate_sm100_e8m0_scales=is_sm100,
|
||||
)
|
||||
if use_fp8_ds_mla_quantization:
|
||||
is_sm100 = torch.cuda.get_device_capability()[0] >= 10
|
||||
kv_c_full, k_pe_squeezed = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=block_size,
|
||||
scale=kv_cache_scale,
|
||||
simulate_sm100_e8m0_scales=is_sm100,
|
||||
)
|
||||
k_pe_full = k_pe_squeezed.unsqueeze(1)
|
||||
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
|
||||
v_mqa = kv_c_full
|
||||
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, ctx_len:] = causal_mask
|
||||
# Compute sparse SDPA reference per query token using its sparse indices
|
||||
for q_idx in range(q_len):
|
||||
tok_sparse_idx = sparse_indices[global_token_idx]
|
||||
valid_mask = tok_sparse_idx >= 0
|
||||
valid_indices = tok_sparse_idx[valid_mask].long()
|
||||
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
q_tok = q_mqa[q_idx : q_idx + 1] # [1, num_heads, head_dim]
|
||||
k_sparse = k_mqa[valid_indices] # [num_valid, head_dim]
|
||||
v_sparse = v_mqa[valid_indices] # [num_valid, kv_lora_rank]
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
k_sparse = k_sparse.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_sparse = v_sparse.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
# SDPA: [1, num_heads, 1, head_dim] x [1, num_heads, num_valid, head_dim]
|
||||
q_sdpa_in = q_tok.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_sparse.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_sparse.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, scale=scale
|
||||
)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(
|
||||
0
|
||||
) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
|
||||
global_token_idx += 1
|
||||
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
@@ -334,42 +408,18 @@ def test_sparse_backend_decode_correctness(
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder_cls = backend_cls.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths
|
||||
)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = debug_indices <= token_positions
|
||||
debug_indices = torch.where(
|
||||
causal_mask, debug_indices, torch.full_like(debug_indices, -1)
|
||||
)
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
# Use the pre-computed sparse_indices for the mock indexer
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=sparse_indices)
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
@@ -383,7 +433,7 @@ def test_sparse_backend_decode_correctness(
|
||||
).to(device=device, dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl_cls = backend_cls.get_impl_cls()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
@@ -441,7 +491,7 @@ def test_sparse_backend_decode_correctness(
|
||||
|
||||
# FP8 quantization introduces some error, but should be within reasonable bounds
|
||||
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
|
||||
if kv_cache_dtype == "fp8_ds_mla":
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
|
||||
else:
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
|
||||
@@ -636,3 +686,63 @@ def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_s
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf)
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_triton_convert_returns_valid_counts():
|
||||
"""Test that return_valid_counts correctly counts non-negative indices."""
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 8
|
||||
num_requests = 2
|
||||
max_blocks_per_req = 10
|
||||
block_size = 64
|
||||
num_topk_tokens = 128
|
||||
|
||||
req_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32, device=device)
|
||||
block_table = torch.arange(
|
||||
num_requests * max_blocks_per_req, dtype=torch.int32, device=device
|
||||
).view(num_requests, max_blocks_per_req)
|
||||
|
||||
# Create token indices with varying numbers of valid entries
|
||||
# Token 0: 64 valid, 64 invalid (-1)
|
||||
# Token 1: 32 valid, 96 invalid
|
||||
# Token 2: 128 valid (all)
|
||||
# Token 3: 1 valid, 127 invalid
|
||||
# etc.
|
||||
token_indices = torch.full(
|
||||
(num_tokens, num_topk_tokens), -1, dtype=torch.int32, device=device
|
||||
)
|
||||
expected_valid = []
|
||||
for i in range(num_tokens):
|
||||
num_valid = [64, 32, 128, 1, 64, 32, 128, 1][i]
|
||||
token_indices[i, :num_valid] = torch.arange(
|
||||
num_valid, dtype=torch.int32, device=device
|
||||
) % (block_size * max_blocks_per_req)
|
||||
expected_valid.append(num_valid)
|
||||
|
||||
expected_valid_tensor = torch.tensor(
|
||||
expected_valid, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Test with return_valid_counts=True
|
||||
result, valid_counts = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
return_valid_counts=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(valid_counts, expected_valid_tensor, rtol=0, atol=0)
|
||||
|
||||
# Test that return_valid_counts=False returns only the indices
|
||||
result_only = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
return_valid_counts=False,
|
||||
)
|
||||
assert isinstance(result_only, torch.Tensor)
|
||||
torch.testing.assert_close(result_only, result, rtol=0, atol=0)
|
||||
|
||||
@@ -901,10 +901,50 @@ def parse_cuda_priority_lists() -> dict[str, list[str]]:
|
||||
|
||||
|
||||
def _get_backends_from_return(stmts: list) -> list[str]:
|
||||
"""Extract backend names from return statements in a list of statements."""
|
||||
"""Extract backend names from return statements in a list of statements.
|
||||
|
||||
Handles starred unpacking (e.g. ``*sparse_backends``) by resolving the
|
||||
variable from assignments found in the same statement list. When the
|
||||
variable is conditionally assigned (inside an ``if/else``), the ``else``
|
||||
branch value is used as the representative default.
|
||||
"""
|
||||
# Collect variable assignments so we can resolve starred expressions.
|
||||
# For conditional assignments, last-written (else branch) wins.
|
||||
var_assigns: dict[str, list[str]] = {}
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.List):
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
var_assigns[target.id] = [
|
||||
e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)
|
||||
]
|
||||
elif isinstance(stmt, ast.If):
|
||||
for branch in (stmt.body, stmt.orelse):
|
||||
for branch_stmt in branch:
|
||||
if isinstance(branch_stmt, ast.Assign) and isinstance(
|
||||
branch_stmt.value, ast.List
|
||||
):
|
||||
for target in branch_stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
var_assigns[target.id] = [
|
||||
e.attr
|
||||
for e in branch_stmt.value.elts
|
||||
if isinstance(e, ast.Attribute)
|
||||
]
|
||||
|
||||
for stmt in stmts:
|
||||
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List):
|
||||
return [e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)]
|
||||
backends: list[str] = []
|
||||
for e in stmt.value.elts:
|
||||
if isinstance(e, ast.Attribute):
|
||||
backends.append(e.attr)
|
||||
elif (
|
||||
isinstance(e, ast.Starred)
|
||||
and isinstance(e.value, ast.Name)
|
||||
and e.value.id in var_assigns
|
||||
):
|
||||
backends.extend(var_assigns[e.value.id])
|
||||
return backends
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -334,6 +334,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
block_size,
|
||||
use_mla=True,
|
||||
use_sparse=use_sparse,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -129,6 +129,7 @@ class CpuPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
|
||||
@@ -45,17 +45,29 @@ torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
device_capability: DeviceCapability,
|
||||
num_heads: int | None = None,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
|
||||
if num_heads is not None and num_heads <= 16:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
]
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
*sparse_backends,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
@@ -182,6 +194,8 @@ class CudaPlatformBase(Platform):
|
||||
use_flashmla = False
|
||||
use_cutlass_mla = False
|
||||
use_flashinfer_mla = False
|
||||
use_flashmla_sparse = False
|
||||
use_flashinfer_mla_sparse = False
|
||||
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
@@ -217,6 +231,10 @@ class CudaPlatformBase(Platform):
|
||||
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
|
||||
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
|
||||
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
|
||||
use_flashinfer_mla_sparse = (
|
||||
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
|
||||
)
|
||||
|
||||
if (
|
||||
use_flashmla
|
||||
@@ -242,12 +260,24 @@ class CudaPlatformBase(Platform):
|
||||
"Forcing kv cache block size to 64 for FlashInferMLA backend."
|
||||
)
|
||||
|
||||
# TODO(Chen): remove this hacky code
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
if use_sparse:
|
||||
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
|
||||
use_flashmla_sparse = True
|
||||
|
||||
if use_flashmla_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
|
||||
32,
|
||||
64,
|
||||
):
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashInferMLASparse "
|
||||
"backend."
|
||||
)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
# Note: model_config may be None during testing
|
||||
@@ -276,6 +306,7 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
device_capability: DeviceCapability,
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
@@ -284,7 +315,9 @@ class CudaPlatformBase(Platform):
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla, device_capability
|
||||
attn_selector_config.use_mla,
|
||||
device_capability,
|
||||
num_heads,
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
@@ -307,6 +340,7 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
@@ -336,6 +370,7 @@ class CudaPlatformBase(Platform):
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
device_capability=device_capability,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
reasons_str = (
|
||||
"{"
|
||||
|
||||
@@ -233,6 +233,7 @@ class Platform:
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
"""Get the attention backend class of a device."""
|
||||
return ""
|
||||
|
||||
@@ -265,6 +265,7 @@ class RocmPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class XPUPlatform(Platform):
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
|
||||
353
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Normal file
353
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""FlashInfer MLA Sparse Attention Backend.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
|
||||
For sparse MLA:
|
||||
- block_tables shape changes from [batch_size, max_num_blocks] (dense)
|
||||
to [batch_size, q_len_per_request, sparse_mla_top_k] (sparse)
|
||||
- The sparse indices represent physical cache slot positions to attend to
|
||||
- sparse_mla_top_k parameter must be set to the topk value
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLASparseBackend(AttentionBackend):
|
||||
"""FlashInfer MLA backend with sparse attention support.
|
||||
|
||||
This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
"""
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [32, 64]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA_SPARSE"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLASparseImpl"]:
|
||||
return FlashInferMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLASparseMetadataBuilder"]:
|
||||
return FlashInferMLASparseMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
# FlashInfer sparse MLA targets Blackwell (SM 10.x)
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if qk_nope_head_dim != 128:
|
||||
return (
|
||||
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
|
||||
f"but got {qk_nope_head_dim}"
|
||||
)
|
||||
# Check for index_topk which indicates sparse model
|
||||
if not hasattr(hf_text_config, "index_topk"):
|
||||
return "FlashInfer MLA Sparse requires model with index_topk config"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
    """FlashInfer sparse MLA requires the HND KV-cache layout."""
    return "HND"
|
||||
|
||||
|
||||
@dataclass
class FlashInferMLASparseMetadata(AttentionMetadata):
    """Attention metadata for FlashInfer MLA Sparse backend."""

    # Number of requests in this batch.
    num_reqs: int
    # Longest query (new-token) span among the requests.
    max_query_len: int
    # Longest total sequence (context + query) among the requests.
    max_seq_len: int
    # Number of real tokens; the builder's buffers may be padded past this.
    num_actual_tokens: int

    # Query start locations
    query_start_loc: torch.Tensor
    # Per-token KV-cache slot for writes.
    slot_mapping: torch.Tensor
    # Per-request physical block table, indexed by request id.
    block_table: torch.Tensor
    # Maps each token row back to its request index; consumed by the
    # sparse index-conversion kernel.
    req_id_per_token: torch.Tensor

    # Sequence lengths for all requests (context + query)
    seq_lens: torch.Tensor

    # Sparse-specific
    block_size: int = 64
    topk_tokens: int = 2048
|
||||
|
||||
|
||||
class FlashInferMLASparseMetadataBuilder(
    AttentionMetadataBuilder[FlashInferMLASparseMetadata]
):
    """Builder for FlashInfer MLA Sparse attention metadata."""

    # Uniform-batch CUDA-graph capture is supported.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Cache config handles and preallocate the req-id mapping buffer.

        NOTE(review): assumes the model's hf_config defines ``index_topk``
        (i.e. a sparse-attention model) — raises AttributeError otherwise.
        """
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        self.device = device

        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk

        # Persistent buffer sized for the worst-case batch so CUDA graphs
        # can replay with a fixed allocation.
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMLASparseMetadata:
        """Build per-step sparse MLA metadata from common attention metadata.

        common_prefix_len and fast_build are accepted for interface
        compatibility but unused here.
        """
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens

        # Build req_id_per_token mapping: expand each request's query span
        # into one request-id entry per token.
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )

        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token_tensor = self.req_id_per_token_buffer[:num_tokens]

        return FlashInferMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=cm.slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token_tensor,
            seq_lens=cm.seq_lens,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
        )
|
||||
|
||||
|
||||
# Global workspace buffer (lazily initialized)
_fi_sparse_workspace: torch.Tensor | None = None


def _get_workspace_buffer(device: torch.device) -> torch.Tensor:
    """Return the process-wide FlashInfer workspace, allocating on first use.

    The buffer is created once on the first requesting device and shared
    by all subsequent callers.
    """
    global _fi_sparse_workspace
    if _fi_sparse_workspace is not None:
        return _fi_sparse_workspace
    _fi_sparse_workspace = torch.zeros(
        FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device=device,
    )
    return _fi_sparse_workspace
|
||||
|
||||
|
||||
class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]):
    """FlashInfer MLA Sparse implementation.

    Uses the TRT-LLM MLA kernel with sparse_mla_top_k parameter for
    sparse attention computation.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: "Indexer | None" = None,
        **mla_args,
    ) -> None:
        """Validate backend constraints and cache MLA dimensions.

        Raises:
            NotImplementedError: if ALiBi, sliding window, logits soft cap,
                or a non-decoder attention type is requested.
        """
        # Reject features this backend does not implement.
        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLASparseImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferMLASparseImpl"
            )

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA-specific dimensions
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"]
        self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"]

        # The indexer supplies the per-token top-k KV indices consumed in
        # forward_mqa.  (topk_indice_buffer is accepted for interface
        # compatibility but unused here.)
        assert indexer is not None, "Indexer required for sparse MLA"
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

        # Lazily initialized on first forward: the workspace needs the
        # query's device, and the scales need the layer's quant scales.
        self._workspace_buffer: torch.Tensor | None = None
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashInferMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run sparse MQA decode via the TRT-LLM MLA kernel.

        Returns (output, None); the second element (LSE) is not produced.
        """
        # q may arrive as (nope, rope) parts; fuse into one tensor.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        num_actual_toks = q.shape[0]

        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # Translate per-request top-k indices into physical cache slots.
        # With return_valid_counts=True the second result is the number of
        # valid (non -1) indices per token row, used as per-row seq_lens.
        topk_indices_physical, seq_lens = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token[:num_actual_toks],
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            return_valid_counts=True,
        )

        if self._workspace_buffer is None:
            self._workspace_buffer = _get_workspace_buffer(q.device)

        # Cache the kernel scales once; they fold the layer's quantization
        # scales with the softmax scale.
        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # Each token row's top-k slot list acts as its own "block table"
        # with block size 1 from the kernel's perspective.
        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q.unsqueeze(1),
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=topk_indices_physical.unsqueeze(1),
            seq_lens=seq_lens,
            max_seq_len=attn_metadata.topk_tokens,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
            sparse_mla_top_k=attn_metadata.topk_tokens,
        )
        return o.view(-1, o.shape[-2], o.shape[-1]), None
|
||||
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -26,6 +25,9 @@ from vllm.v1.attention.backend import (
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.sparse_utils import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
@@ -203,166 +205,6 @@ class FlashMLASparseMetadata(AttentionMetadata):
|
||||
fp8_use_mixed_batch: bool = False
|
||||
|
||||
|
||||
# Kernel with prefill workspace support
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # Translate per-request token indices into global KV-cache slot indices
    # (or prefill-workspace offsets); invalid entries become -1.  See the
    # host-side wrapper for the full contract.
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0
    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0
    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access; masked lanes load 0 and are overwritten
    # with -1 below via is_invalid_tok.
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id],
                    token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1.
    For safety, we also output -1 if the derived block_id would be
    out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
    instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )

    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    # When HAS_PREFILL_WORKSPACE is False, the prefill pointers may be
    # None; the kernel specializes on HAS_PREFILL and never reads them.
    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out
|
||||
|
||||
|
||||
def get_prefill_workspace_size(max_model_len: int):
|
||||
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
|
||||
191
vllm/v1/attention/backends/mla/sparse_utils.py
Normal file
191
vllm/v1/attention/backends/mla/sparse_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for sparse MLA backends."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
# Kernel with prefill workspace support and valid count tracking
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    valid_count_ptr,  # int32 [num_tokens] - output valid count per row
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    COUNT_VALID: tl.constexpr,  # whether to count valid indices
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # Translate per-request token indices into global KV-cache slot indices
    # (or prefill-workspace offsets); invalid entries become -1.  See the
    # host-side wrapper for the full contract.
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Only token == -1 should propagate as -1
    is_invalid_tok = tok < 0
    is_prefill = False
    if HAS_PREFILL:
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0
    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access; masked lanes load 0 and are overwritten
    # with -1 below via is_invalid_tok.
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    is_invalid_tok |= ~valid_block
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)

    # Count valid indices in this tile and atomically add to row total
    # (atomic because all column tiles of a row accumulate into one slot;
    # the host must zero-initialize valid_count_ptr).
    if COUNT_VALID:
        tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
        tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
    return_valid_counts: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    out[token_id, indice_id] =
        block_table[req_id[token_id],
                    token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Only when token_indices[token_id, indice_id] == -1 do we output -1.
    For safety, we also output -1 if the derived block_id would be
    out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
    instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request

    When return_valid_counts is True, also returns the count of valid (non -1)
    indices per row, computed during the same kernel pass (no extra overhead).
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )

    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors on the same device
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Allocate valid count buffer if needed (must be zero-initialized for atomics)
    valid_counts: torch.Tensor | None = None
    if return_valid_counts:
        valid_counts = torch.zeros(
            num_tokens, dtype=torch.int32, device=token_indices.device
        )

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    # valid_counts / prefill pointers may be None; the kernel specializes
    # on the COUNT_VALID / HAS_PREFILL constexprs and never reads them then.
    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        valid_counts,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        return_valid_counts,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )

    if return_valid_counts:
        assert valid_counts is not None
        return out, valid_counts
    return out
|
||||
@@ -62,6 +62,10 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
FLASHINFER_MLA = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
FLASHINFER_MLA_SPARSE = (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla_sparse."
|
||||
"FlashInferMLASparseBackend"
|
||||
)
|
||||
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
|
||||
@@ -53,6 +53,7 @@ def get_attn_backend(
|
||||
use_sparse: bool = False,
|
||||
use_mm_prefix: bool = False,
|
||||
attn_type: str | None = None,
|
||||
num_heads: int | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
@@ -66,7 +67,6 @@ def get_attn_backend(
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
backend_enum = vllm_config.attention_config.backend
|
||||
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=head_size,
|
||||
@@ -81,8 +81,9 @@ def get_attn_backend(
|
||||
)
|
||||
|
||||
return _cached_get_attn_backend(
|
||||
backend=backend_enum,
|
||||
backend=vllm_config.attention_config.backend,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,12 +91,14 @@ def get_attn_backend(
|
||||
def _cached_get_attn_backend(
|
||||
backend,
|
||||
attn_selector_config: AttentionSelectorConfig,
|
||||
num_heads: int | None = None,
|
||||
) -> type[AttentionBackend]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
backend,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user