[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""
"""Unit tests for the sparse MLA backends and utilities."""

import math
from types import MethodType, SimpleNamespace

import numpy as np
import pytest
import torch
@@ -25,6 +24,9 @@ from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashinfer_mla_sparse import (
    FlashInferMLASparseBackend,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    FlashMLASparseBackend,
    triton_convert_req_index_to_global_index,
@@ -156,32 +158,48 @@ def _quantize_dequantize_fp8_ds_mla(
    return dequant_kv_c, dequant_k_pe


@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
    torch.cuda.get_device_capability() < (9, 0),
    reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
@pytest.mark.parametrize(
    "backend_cls",
    [FlashMLASparseBackend, FlashInferMLASparseBackend],
    ids=["FlashMLA", "FlashInfer"],
)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.parametrize("block_size", [32, 64])
def test_sparse_backend_decode_correctness(
    default_vllm_config,
    dist_init,
    backend_cls,
    batch_name,
    kv_cache_dtype,
    tensor_parallel_size,
    block_size,
    workspace_init,
):
    if current_platform.is_rocm():
        pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
    if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
        pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")

    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for sparse MLA decode test")
    supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
    if block_size not in supported_block_sizes:
        pytest.skip(
            f"{backend_cls.get_name()} does not support block_size={block_size}"
        )

    if backend_cls == FlashMLASparseBackend:
        ok, reason = flashmla.is_flashmla_sparse_supported()
        if not ok:
            pytest.skip(reason)
    elif backend_cls == FlashInferMLASparseBackend:
        if not current_platform.has_device_capability(100):
            pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher")

    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
    use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"

    device = torch.device("cuda")
    dtype = torch.bfloat16

    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]

    # Model hyper-parameters (kept intentionally small for the unit test)
    total_num_heads = 128
    # Compute per-rank heads for simulated TP
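The decorator stack above turns the single test into a backend x batch x kv-cache-dtype x TP x block-size matrix, and combinations a backend cannot run are skipped inside the test body rather than filtered out of the matrix. A minimal standalone illustration of how stacked parametrize decorators multiply cases (toy test name and values, not from this file):

import pytest

@pytest.mark.parametrize("backend", ["backend_a", "backend_b"])
@pytest.mark.parametrize("block_size", [32, 64])
def test_case_matrix(backend, block_size):
    # pytest builds the Cartesian product of all stacked parametrize
    # decorators: 2 backends x 2 block sizes = 4 generated cases.
    if backend == "backend_b" and block_size == 32:
        pytest.skip("unsupported combination is pruned at runtime, not here")
    assert block_size in (32, 64)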
@@ -192,11 +210,10 @@ def test_sparse_backend_decode_correctness(
    qk_rope_head_dim = 64
    v_head_dim = 128
    head_size = kv_lora_rank + qk_rope_head_dim
    topk_tokens = 2048
    topk_tokens = 128

    max_seqlen = max(batch_spec.seq_lens)
    total_cache_tokens = sum(batch_spec.seq_lens)
    block_size = 64

    # Note: We use TP=1 to avoid multi-GPU requirements in CI.
    # The test simulates head partitioning via mocked methods below.
@@ -247,11 +264,55 @@ def test_sparse_backend_decode_correctness(
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens

    # Pre-compute positions and sparse indices for all tokens.
    # We need these BEFORE computing the reference to use sparse attention masks.
    total_query_tokens = sum(query_lens)
    positions = []
    for i in range(batch_spec.batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        ctx_len = s_len - q_len
        for q_idx in range(q_len):
            positions.append(ctx_len + q_idx)

    # Create sparse indices with UNIQUE per-token offsets to catch bugs where
    # the kernel uses wrong indices for some tokens (e.g., due to incorrect
    # tensor shapes like [1, num_tokens, ...] instead of [num_tokens, 1, ...]).
    # Also include -1 masked indices to verify the kernel handles them correctly.
    sparse_indices = torch.empty(
        total_query_tokens, topk_tokens, dtype=torch.int32, device=device
    )
    for tok_idx in range(total_query_tokens):
        max_valid_idx = positions[tok_idx]
        offset = tok_idx * 7  # Prime number for varied offsets
        # Use only half the topk indices as valid, mask the rest with -1
        # This tests that the kernel correctly ignores -1 indices
        num_valid = min(topk_tokens // 2, max_valid_idx + 1)
        if num_valid > 0:
            valid_range = torch.arange(num_valid, device=device, dtype=torch.int32)
            tok_indices = (valid_range + offset) % (max_valid_idx + 1)
            # Pad with -1 for the remaining positions
            tok_indices = torch.cat(
                [
                    tok_indices,
                    torch.full(
                        (topk_tokens - num_valid,), -1, device=device, dtype=torch.int32
                    ),
                ]
            )
        else:
            tok_indices = torch.full(
                (topk_tokens,), -1, device=device, dtype=torch.int32
            )
            tok_indices[0] = 0  # At least one valid index
        sparse_indices[tok_idx] = tok_indices

    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    kv_c_contexts, k_pe_contexts = [], []
    reference_outputs = []

    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    global_token_idx = 0

    for i in range(batch_spec.batch_size):
        s_len = seq_lens[i]
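The -1 entries generated above are padding that a correct kernel must skip rather than treat as slot 0 or read out of bounds. A minimal sketch of how a reference implementation can honor that contract for one token's index list (the helper name is illustrative, not part of the backend API):

import torch

def gather_valid_kv(kv: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # kv: [num_cached_tokens, head_dim]; indices: [topk], padded with -1.
    valid = indices[indices >= 0].long()  # drop the -1 padding entries
    return kv[valid]                      # [num_valid, head_dim]

# Example: only two of the four top-k slots hold real indices.
kv = torch.randn(8, 4)
idx = torch.tensor([3, 5, -1, -1], dtype=torch.int32)
print(gather_valid_kv(kv, idx).shape)     # torch.Size([2, 4])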
@@ -268,40 +329,53 @@ def test_sparse_backend_decode_correctness(
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
        k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
        # which truncates scales to powers of 2. Simulate this in reference.
        is_sm100 = torch.cuda.get_device_capability()[0] >= 10
        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
            kv_c_full,
            k_pe_full.squeeze(1),
            block_size=vllm_config.cache_config.block_size,
            scale=kv_cache_scale,
            simulate_sm100_e8m0_scales=is_sm100,
        )
        if use_fp8_ds_mla_quantization:
            is_sm100 = torch.cuda.get_device_capability()[0] >= 10
            kv_c_full, k_pe_squeezed = _quantize_dequantize_fp8_ds_mla(
                kv_c_full,
                k_pe_full.squeeze(1),
                block_size=block_size,
                scale=kv_cache_scale,
                simulate_sm100_e8m0_scales=is_sm100,
            )
            k_pe_full = k_pe_squeezed.unsqueeze(1)

        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
        ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)

        k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
        k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
        v_mqa = kv_c_full

        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, ctx_len:] = causal_mask
        # Compute sparse SDPA reference per query token using its sparse indices
        for q_idx in range(q_len):
            tok_sparse_idx = sparse_indices[global_token_idx]
            valid_mask = tok_sparse_idx >= 0
            valid_indices = tok_sparse_idx[valid_mask].long()

        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
            q_tok = q_mqa[q_idx : q_idx + 1]  # [1, num_heads, head_dim]
            k_sparse = k_mqa[valid_indices]  # [num_valid, head_dim]
            v_sparse = v_mqa[valid_indices]  # [num_valid, kv_lora_rank]

        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
            k_sparse = k_sparse.unsqueeze(1).expand(-1, num_heads, -1)
            v_sparse = v_sparse.unsqueeze(1).expand(-1, num_heads, -1)

        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
        reference_outputs.append(sdpa_out.flatten(start_dim=-2))
            # SDPA: [1, num_heads, 1, head_dim] x [1, num_heads, num_valid, head_dim]
            q_sdpa_in = q_tok.unsqueeze(0).transpose(1, 2)
            k_sdpa_in = k_sparse.unsqueeze(0).transpose(1, 2)
            v_sdpa_in = v_sparse.unsqueeze(0).transpose(1, 2)

            sdpa_out = torch.nn.functional.scaled_dot_product_attention(
                q_sdpa_in, k_sdpa_in, v_sdpa_in, scale=scale
            )
            sdpa_out = sdpa_out.transpose(1, 2).squeeze(
                0
            )  # [1, num_heads, kv_lora_rank]

            sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
            reference_outputs.append(sdpa_out.flatten(start_dim=-2))

            global_token_idx += 1

        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
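The reference path above simulates the SM100 behavior noted in the comment: per-block fp8 scales round-trip through e8m0, which can only represent powers of two. A minimal sketch of that truncation (helper name and round-to-nearest rounding are assumptions; the hardware conversion may round differently):

import torch

def to_power_of_two_scale(scale: torch.Tensor) -> torch.Tensor:
    # e8m0 keeps only an 8-bit exponent, so the stored scale collapses to
    # 2**round(log2(scale)); any mantissa information is lost.
    return torch.exp2(torch.round(torch.log2(scale)))

print(to_power_of_two_scale(torch.tensor([0.75, 1.0, 3.0])))
# tensor([1., 1., 4.])  -> only powers of two survive the round trip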
@@ -334,42 +408,18 @@ def test_sparse_backend_decode_correctness(
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=False,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
        scale=kv_cache_scale,
    )

    builder_cls = FlashMLASparseBackend.get_builder_cls()
    builder_cls = backend_cls.get_builder_cls()
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
    metadata = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
        starts[:-1], seg_lengths
    )
    seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
    prefix_lengths = seq_lengths - seg_lengths
    positions += np.repeat(prefix_lengths, seg_lengths)

    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
    topk = metadata.topk_tokens
    debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
    token_positions = pos_gpu.unsqueeze(1)
    causal_mask = debug_indices <= token_positions
    debug_indices = torch.where(
        causal_mask, debug_indices, torch.full_like(debug_indices, -1)
    )

    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
    # buffer, so emulate that contract with a simple namespace mock.
    debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)

    ok, reason = flashmla.is_flashmla_sparse_supported()
    if not ok:
        pytest.skip(reason)
    # Use the pre-computed sparse_indices for the mock indexer
    mock_indexer = SimpleNamespace(topk_indices_buffer=sparse_indices)

    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
    kv_b_proj_weight = kv_b_proj_weight.view(
@@ -383,7 +433,7 @@ def test_sparse_backend_decode_correctness(
    ).to(device=device, dtype=dtype)
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())

    impl_cls = FlashMLASparseBackend.get_impl_cls()
    impl_cls = backend_cls.get_impl_cls()
    with set_current_vllm_config(vllm_config):
        impl = impl_cls(
            num_heads=num_heads,
@@ -441,7 +491,7 @@ def test_sparse_backend_decode_correctness(

    # FP8 quantization introduces some error, but should be within reasonable bounds
    # BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
    if kv_cache_dtype == "fp8_ds_mla":
    if kv_cache_dtype.startswith("fp8"):
        torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
    else:
        torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
@@ -636,3 +686,63 @@ def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_s
def test_split_prefill_chunks(seq_lens, max_buf, expected):
    out = split_prefill_chunks(seq_lens, max_buf)
    assert out == expected


def test_triton_convert_returns_valid_counts():
    """Test that return_valid_counts correctly counts non-negative indices."""
    device = torch.device("cuda")
    num_tokens = 8
    num_requests = 2
    max_blocks_per_req = 10
    block_size = 64
    num_topk_tokens = 128

    req_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32, device=device)
    block_table = torch.arange(
        num_requests * max_blocks_per_req, dtype=torch.int32, device=device
    ).view(num_requests, max_blocks_per_req)

    # Create token indices with varying numbers of valid entries
    # Token 0: 64 valid, 64 invalid (-1)
    # Token 1: 32 valid, 96 invalid
    # Token 2: 128 valid (all)
    # Token 3: 1 valid, 127 invalid
    # etc.
    token_indices = torch.full(
        (num_tokens, num_topk_tokens), -1, dtype=torch.int32, device=device
    )
    expected_valid = []
    for i in range(num_tokens):
        num_valid = [64, 32, 128, 1, 64, 32, 128, 1][i]
        token_indices[i, :num_valid] = torch.arange(
            num_valid, dtype=torch.int32, device=device
        ) % (block_size * max_blocks_per_req)
        expected_valid.append(num_valid)

    expected_valid_tensor = torch.tensor(
        expected_valid, dtype=torch.int32, device=device
    )

    # Test with return_valid_counts=True
    result, valid_counts = triton_convert_req_index_to_global_index(
        req_id,
        block_table,
        token_indices,
        BLOCK_SIZE=block_size,
        NUM_TOPK_TOKENS=num_topk_tokens,
        return_valid_counts=True,
    )

    torch.testing.assert_close(valid_counts, expected_valid_tensor, rtol=0, atol=0)

    # Test that return_valid_counts=False returns only the indices
    result_only = triton_convert_req_index_to_global_index(
        req_id,
        block_table,
        token_indices,
        BLOCK_SIZE=block_size,
        NUM_TOPK_TOKENS=num_topk_tokens,
        return_valid_counts=False,
    )
    assert isinstance(result_only, torch.Tensor)
    torch.testing.assert_close(result_only, result, rtol=0, atol=0)
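For readers unfamiliar with the kernel under test: it maps each request-local top-k token index to a global slot in the paged KV cache through that request's block table, and can additionally report how many indices per token were valid. A plain PyTorch sketch of that expected mapping (hypothetical helper; the assumption that -1 padding passes through unchanged mirrors how the rest of the test treats masked indices):

import torch

def convert_req_to_global(req_id, block_table, token_indices, block_size):
    # Valid (>= 0) request-local indices become global KV-cache slots via the
    # request's block table; -1 padding is left untouched. valid_counts is the
    # number of non-negative indices per token, as asserted in the test above.
    out = torch.full_like(token_indices, -1)
    valid = token_indices >= 0
    req = req_id.long().unsqueeze(1).expand_as(token_indices)
    idx = token_indices.long()
    blocks = block_table.long()[req[valid], idx[valid] // block_size]
    out[valid] = (blocks * block_size + idx[valid] % block_size).to(out.dtype)
    valid_counts = valid.sum(dim=1, dtype=torch.int32)
    return out, valid_counts

With the arange block table used in the test, request 0's token index t would land at slot t and request 1's at 640 + t (an offset of 10 blocks x 64 slots), under this assumed mapping.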