[Attention] Add FlashInfer Sparse MLA backend (#33451)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Matthew Bonanni
2026-02-12 12:21:54 -05:00
committed by GitHub
parent 334c715e0f
commit f2c47886fd
24 changed files with 1181 additions and 408 deletions

@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""
"""Unit tests for the sparse MLA backends and utilities."""
import math
from types import MethodType, SimpleNamespace
import numpy as np
import pytest
import torch
@@ -25,6 +24,9 @@ from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashinfer_mla_sparse import (
FlashInferMLASparseBackend,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
@@ -156,32 +158,48 @@ def _quantize_dequantize_fp8_ds_mla(
return dequant_kv_c, dequant_k_pe
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
@pytest.mark.parametrize(
"backend_cls",
[FlashMLASparseBackend, FlashInferMLASparseBackend],
ids=["FlashMLA", "FlashInfer"],
)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.parametrize("block_size", [32, 64])
def test_sparse_backend_decode_correctness(
default_vllm_config,
dist_init,
backend_cls,
batch_name,
kv_cache_dtype,
tensor_parallel_size,
block_size,
workspace_init,
):
if current_platform.is_rocm():
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
if block_size not in supported_block_sizes:
pytest.skip(
f"{backend_cls.get_name()} does not support block_size={block_size}"
)
if backend_cls == FlashMLASparseBackend:
ok, reason = flashmla.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
elif backend_cls == FlashInferMLASparseBackend:
if not current_platform.has_device_capability(100):
pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher")
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
device = torch.device("cuda")
dtype = torch.bfloat16
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
# Model hyper-parameters (kept intentionally small for the unit test)
total_num_heads = 128
# Compute per-rank heads for simulated TP
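Editor's note: the skip chain above doubles as a runtime capability check for the two sparse MLA backends. As a minimal sketch (not part of this commit, and assuming the same flashmla helper and platform imports the test file already uses), the same hooks could select a usable backend for a given KV-cache dtype and block size:

# Illustrative sketch only; pick_sparse_mla_backend is hypothetical, but the
# attributes and helpers it calls are exactly the ones exercised by the test above.
def pick_sparse_mla_backend(kv_cache_dtype: str, block_size: int):
    for backend_cls in (FlashMLASparseBackend, FlashInferMLASparseBackend):
        if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
            continue
        if block_size not in backend_cls.get_supported_kernel_block_sizes():
            continue
        if backend_cls is FlashMLASparseBackend:
            ok, _reason = flashmla.is_flashmla_sparse_supported()  # assumed in scope, as in the test
            if not ok:
                continue
        elif not current_platform.has_device_capability(100):  # FlashInfer path needs SM 10.0+
            continue
        return backend_cls
    return None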
@@ -192,11 +210,10 @@ def test_sparse_backend_decode_correctness(
qk_rope_head_dim = 64
v_head_dim = 128
head_size = kv_lora_rank + qk_rope_head_dim
topk_tokens = 2048
topk_tokens = 128
max_seqlen = max(batch_spec.seq_lens)
total_cache_tokens = sum(batch_spec.seq_lens)
block_size = 64
# Note: We use TP=1 to avoid multi-GPU requirements in CI.
# The test simulates head partitioning via mocked methods below.
@@ -247,11 +264,55 @@ def test_sparse_backend_decode_correctness(
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
# Pre-compute positions and sparse indices for all tokens.
# We need these BEFORE computing the reference to use sparse attention masks.
total_query_tokens = sum(query_lens)
positions = []
for i in range(batch_spec.batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
ctx_len = s_len - q_len
for q_idx in range(q_len):
positions.append(ctx_len + q_idx)
# Create sparse indices with UNIQUE per-token offsets to catch bugs where
# the kernel uses wrong indices for some tokens (e.g., due to incorrect
# tensor shapes like [1, num_tokens, ...] instead of [num_tokens, 1, ...]).
# Also include -1 masked indices to verify the kernel handles them correctly.
sparse_indices = torch.empty(
total_query_tokens, topk_tokens, dtype=torch.int32, device=device
)
for tok_idx in range(total_query_tokens):
max_valid_idx = positions[tok_idx]
offset = tok_idx * 7 # Prime number for varied offsets
# Use only half the topk indices as valid, mask the rest with -1
# This tests that the kernel correctly ignores -1 indices
num_valid = min(topk_tokens // 2, max_valid_idx + 1)
if num_valid > 0:
valid_range = torch.arange(num_valid, device=device, dtype=torch.int32)
tok_indices = (valid_range + offset) % (max_valid_idx + 1)
# Pad with -1 for the remaining positions
tok_indices = torch.cat(
[
tok_indices,
torch.full(
(topk_tokens - num_valid,), -1, device=device, dtype=torch.int32
),
]
)
else:
tok_indices = torch.full(
(topk_tokens,), -1, device=device, dtype=torch.int32
)
tok_indices[0] = 0 # At least one valid index
sparse_indices[tok_idx] = tok_indices
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
kv_c_contexts, k_pe_contexts = [], []
reference_outputs = []
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
global_token_idx = 0
for i in range(batch_spec.batch_size):
s_len = seq_lens[i]
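Editor's note (not part of the diff): the index pattern built by the loop above is easier to see on a shrunk example. Each query token gets a wrapped run of valid indices starting at its own prime-multiple offset, padded out with -1. A standalone snippet with topk_tokens=8:

import torch

# One row of the synthetic sparse-index pattern, shrunk for readability.
topk_tokens, tok_idx, max_valid_idx = 8, 3, 10  # 4th query token, position 10
offset = tok_idx * 7                            # 21
num_valid = min(topk_tokens // 2, max_valid_idx + 1)  # 4
row = (torch.arange(num_valid, dtype=torch.int32) + offset) % (max_valid_idx + 1)
row = torch.cat([row, torch.full((topk_tokens - num_valid,), -1, dtype=torch.int32)])
print(row.tolist())  # [10, 0, 1, 2, -1, -1, -1, -1]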
@@ -268,40 +329,53 @@ def test_sparse_backend_decode_correctness(
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
# SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
# which truncates scales to powers of 2. Simulate this in reference.
is_sm100 = torch.cuda.get_device_capability()[0] >= 10
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=vllm_config.cache_config.block_size,
scale=kv_cache_scale,
simulate_sm100_e8m0_scales=is_sm100,
)
if use_fp8_ds_mla_quantization:
is_sm100 = torch.cuda.get_device_capability()[0] >= 10
kv_c_full, k_pe_squeezed = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=block_size,
scale=kv_cache_scale,
simulate_sm100_e8m0_scales=is_sm100,
)
k_pe_full = k_pe_squeezed.unsqueeze(1)
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
v_mqa = kv_c_full
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, ctx_len:] = causal_mask
# Compute sparse SDPA reference per query token using its sparse indices
for q_idx in range(q_len):
tok_sparse_idx = sparse_indices[global_token_idx]
valid_mask = tok_sparse_idx >= 0
valid_indices = tok_sparse_idx[valid_mask].long()
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
q_tok = q_mqa[q_idx : q_idx + 1] # [1, num_heads, head_dim]
k_sparse = k_mqa[valid_indices] # [num_valid, head_dim]
v_sparse = v_mqa[valid_indices] # [num_valid, kv_lora_rank]
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
k_sparse = k_sparse.unsqueeze(1).expand(-1, num_heads, -1)
v_sparse = v_sparse.unsqueeze(1).expand(-1, num_heads, -1)
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
# SDPA: [1, num_heads, 1, head_dim] x [1, num_heads, num_valid, head_dim]
q_sdpa_in = q_tok.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_sparse.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_sparse.unsqueeze(0).transpose(1, 2)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, scale=scale
)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(
0
) # [1, num_heads, kv_lora_rank]
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
global_token_idx += 1
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[ctx_len:])
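Editor's note on the simulate_sm100_e8m0_scales flag used above: _quantize_dequantize_fp8_ds_mla is defined earlier in the file and is not shown in this hunk. Per the comment in the reference path, on SM100 the per-block KV scale goes through a float -> e8m0 -> bf16 round trip and therefore collapses to a power of two. A minimal sketch of that rounding, assuming round-up to the next power of two (the exact rounding mode is a kernel detail and an assumption here):

import torch

# Sketch: e8m0 has 8 exponent bits and no mantissa, so a positive scale can only
# be represented as a power of two. Rounding up is assumed, not confirmed.
def e8m0_round_scale(scale: torch.Tensor) -> torch.Tensor:
    return torch.exp2(torch.ceil(torch.log2(scale)))

print(e8m0_round_scale(torch.tensor([0.75, 1.0, 1.3])))  # tensor([1., 1., 2.])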
@@ -334,42 +408,18 @@ def test_sparse_backend_decode_correctness(
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
scale=kv_cache_scale,
)
builder_cls = FlashMLASparseBackend.get_builder_cls()
builder_cls = backend_cls.get_builder_cls()
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
metadata = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths
)
seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
topk = metadata.topk_tokens
debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
token_positions = pos_gpu.unsqueeze(1)
causal_mask = debug_indices <= token_positions
debug_indices = torch.where(
causal_mask, debug_indices, torch.full_like(debug_indices, -1)
)
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
# buffer, so emulate that contract with a simple namespace mock.
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
# Use the pre-computed sparse_indices for the mock indexer
mock_indexer = SimpleNamespace(topk_indices_buffer=sparse_indices)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
kv_b_proj_weight = kv_b_proj_weight.view(
@@ -383,7 +433,7 @@ def test_sparse_backend_decode_correctness(
).to(device=device, dtype=dtype)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl_cls = backend_cls.get_impl_cls()
with set_current_vllm_config(vllm_config):
impl = impl_cls(
num_heads=num_heads,
@@ -441,7 +491,7 @@ def test_sparse_backend_decode_correctness(
# FP8 quantization introduces some error, but should be within reasonable bounds
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
if kv_cache_dtype == "fp8_ds_mla":
if kv_cache_dtype.startswith("fp8"):
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
else:
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
@@ -636,3 +686,63 @@ def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_s
def test_split_prefill_chunks(seq_lens, max_buf, expected):
out = split_prefill_chunks(seq_lens, max_buf)
assert out == expected
def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")
num_tokens = 8
num_requests = 2
max_blocks_per_req = 10
block_size = 64
num_topk_tokens = 128
req_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32, device=device)
block_table = torch.arange(
num_requests * max_blocks_per_req, dtype=torch.int32, device=device
).view(num_requests, max_blocks_per_req)
# Create token indices with varying numbers of valid entries
# Token 0: 64 valid, 64 invalid (-1)
# Token 1: 32 valid, 96 invalid
# Token 2: 128 valid (all)
# Token 3: 1 valid, 127 invalid
# etc.
token_indices = torch.full(
(num_tokens, num_topk_tokens), -1, dtype=torch.int32, device=device
)
expected_valid = []
for i in range(num_tokens):
num_valid = [64, 32, 128, 1, 64, 32, 128, 1][i]
token_indices[i, :num_valid] = torch.arange(
num_valid, dtype=torch.int32, device=device
) % (block_size * max_blocks_per_req)
expected_valid.append(num_valid)
expected_valid_tensor = torch.tensor(
expected_valid, dtype=torch.int32, device=device
)
# Test with return_valid_counts=True
result, valid_counts = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
return_valid_counts=True,
)
torch.testing.assert_close(valid_counts, expected_valid_tensor, rtol=0, atol=0)
# Test that return_valid_counts=False returns only the indices
result_only = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
return_valid_counts=False,
)
assert isinstance(result_only, torch.Tensor)
torch.testing.assert_close(result_only, result, rtol=0, atol=0)
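Editor's note: the Triton kernel itself is outside this diff. From the arguments the two tests above pass, the conversion is assumed to map each per-request token index through that request's block table into a global KV-cache slot, leave -1 entries as -1, and (with return_valid_counts=True) also report how many entries per token are non-negative. A pure-PyTorch sketch of that assumed contract, for orientation only:

import torch

# Assumed reference semantics; a sketch, not the actual Triton kernel.
def convert_reference(req_id, block_table, token_indices, BLOCK_SIZE):
    valid = token_indices >= 0
    safe = token_indices.clamp(min=0)
    blocks = block_table[req_id.long()]  # [num_tokens, max_blocks_per_req]
    global_idx = (
        blocks.gather(1, (safe // BLOCK_SIZE).long()) * BLOCK_SIZE + safe % BLOCK_SIZE
    )
    global_idx = torch.where(valid, global_idx, torch.full_like(global_idx, -1))
    return global_idx, valid.sum(dim=1, dtype=torch.int32)

Under that contract the second return value matches expected_valid_tensor above, which is the property the new test asserts exactly.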