[Attention] FlashAttn MLA (#14258)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-09-04 05:47:59 -04:00
committed by GitHub
parent 2c301ee2eb
commit 402759d472
22 changed files with 480 additions and 200 deletions

View File

@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
_Backend.TRITON_MLA_VLLM_V1
]
@@ -69,20 +69,6 @@ BATCH_SPECS = {
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.head_size, # latent dimension
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
@@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], []
# Create shared MLA weight matrices for consistency across all sequences
@@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])
for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
@@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype,
device=device)
# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
all_sdpa_outputs.append(sdpa_out_i)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Single attention call with custom mask
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
@@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
@@ -486,7 +485,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True)
# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
@@ -494,12 +493,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj)
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_outputs[i].dtype}")
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
@@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output,
sdpa_output,
sdpa_outputs[i],
rtol=rtol,
atol=atol)