[Attention] FlashAttn MLA (#14258)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
|
||||
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
|
||||
_Backend.TRITON_MLA_VLLM_V1
|
||||
]
|
||||
|
||||
@@ -69,20 +69,6 @@ BATCH_SPECS = {
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Build a randomly initialised stand-in KV cache tensor for tests.

    The result is drawn from ``torch.randn`` with shape
    ``(num_blocks, kv_cache_spec.block_size, kv_cache_spec.head_size)``.

    Args:
        kv_cache_spec: Supplies ``block_size``, ``head_size`` and ``dtype``.
        device: Device on which the tensor is allocated.
        num_blocks: Size of the leading (block) dimension.

    Returns:
        The random tensor, using the torch dtype produced by passing
        ``kv_cache_spec.dtype`` through ``_convert_dtype_to_torch``.
    """
    cache_shape = (
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.head_size,  # latent dimension
    )
    torch_dtype = _convert_dtype_to_torch(kv_cache_spec.dtype)
    return torch.randn(cache_shape, dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts: list[torch.Tensor],
|
||||
k_pe_contexts: list[torch.Tensor],
|
||||
@@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
|
||||
# 2. Generate data and compute SDPA reference output for MLA
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
all_sdpa_outputs = []
|
||||
all_sdpa_outputs: list[list[torch.Tensor]] = []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
|
||||
# Create shared MLA weight matrices for consistency across all sequences
|
||||
@@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
device=device)
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
all_sdpa_outputs.append([])
|
||||
|
||||
for i in range(batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
@@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# Determine if this is decode (single token)
|
||||
# or prefill (multiple tokens)
|
||||
is_decode = q_len == 1
|
||||
# Determine if this is decode or prefill
|
||||
is_decode = []
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
builder_cls, _ = get_attention_backend(backend)
|
||||
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
|
||||
|
||||
# Split q into nope and rope components
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
|
||||
if is_decode:
|
||||
# Decode path: MQA-style attention in latent space
|
||||
# Transform q_nope to latent space: q_nope @ W_UK
|
||||
# q_nope: [1, num_heads, qk_nope_head_dim]
|
||||
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
|
||||
W_UK) # [1, num_heads, kv_lora_rank]
|
||||
#######################################################
|
||||
# Decode path: MQA-style attention in latent space
|
||||
# Transform q_nope to latent space: q_nope @ W_UK
|
||||
# q_nope: [1, num_heads, qk_nope_head_dim]
|
||||
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
|
||||
W_UK) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Build MQA attention inputs
|
||||
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
|
||||
# (broadcasted to all heads)
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
|
||||
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
|
||||
# Build MQA attention inputs
|
||||
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
|
||||
# (broadcasted to all heads)
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
|
||||
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
|
||||
|
||||
# SDPA expects (N, H, L, D)
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
# Create custom attention mask for decode path:
|
||||
# - Query tokens can attend to all context tokens
|
||||
# - Query tokens can only attend to query tokens up to their position
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
# Apply causal mask only to the query portion (context_len onwards)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, context_len:] = causal_mask
|
||||
|
||||
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
|
||||
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
|
||||
0) # [1, num_heads, kv_lora_rank]
|
||||
# SDPA expects (N, H, L, D)
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
# Project back to output space: sdpa_out @ W_UV
|
||||
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
|
||||
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
|
||||
else:
|
||||
# Prefill path: MHA-style attention with full sequence
|
||||
# Apply kv_b_proj to the full kv_c tensor
|
||||
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
|
||||
kv_b_proj_weight)
|
||||
k_nope_full, v_full = kv_nope_full.split(
|
||||
[qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
|
||||
0) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Build attention inputs for full sequence
|
||||
q_mha = torch.cat([q_nope, q_pe],
|
||||
dim=-1) # [q_len, num_heads, total_dim]
|
||||
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
|
||||
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
|
||||
# Project back to output space: sdpa_out @ W_UV
|
||||
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
|
||||
W_UV)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
|
||||
|
||||
# Create custom attention mask:
|
||||
# - Query tokens can attend to all context tokens
|
||||
# - Query tokens can only attend to query tokens up to their pos
|
||||
attn_mask = torch.ones(q_len,
|
||||
s_len,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
# Apply causal mask only to the query portion (context_len onwards)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, context_len:] = causal_mask
|
||||
#######################################################
|
||||
# Prefill path: MHA-style attention with full sequence
|
||||
# Apply kv_b_proj to the full kv_c tensor
|
||||
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
|
||||
k_nope_full, v_full = kv_nope_full.split(
|
||||
[qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
|
||||
# SDPA expects (N, H, L, D)
|
||||
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
|
||||
# Build attention inputs for full sequence
|
||||
q_mha = torch.cat([q_nope, q_pe],
|
||||
dim=-1) # [q_len, num_heads, total_dim]
|
||||
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
|
||||
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
|
||||
|
||||
# Single attention call with custom mask
|
||||
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in,
|
||||
k_sdpa_in,
|
||||
v_sdpa_in,
|
||||
attn_mask=attn_mask,
|
||||
scale=scale)
|
||||
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
|
||||
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
|
||||
# Create custom attention mask:
|
||||
# - Query tokens can attend to all context tokens
|
||||
# - Query tokens can only attend to query tokens up to their pos
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
# Apply causal mask only to the query portion (context_len onwards)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, context_len:] = causal_mask
|
||||
|
||||
all_sdpa_outputs.append(sdpa_out_i)
|
||||
# SDPA expects (N, H, L, D)
|
||||
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
# Single attention call with custom mask
|
||||
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
if is_decode[i]:
|
||||
all_sdpa_outputs[i].append(sdpa_out_i_decode)
|
||||
else:
|
||||
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
|
||||
|
||||
# Inputs for vLLM MLA backends are just the new tokens
|
||||
all_q_vllm.append(q_c)
|
||||
@@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
|
||||
sdpa_outputs = []
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
|
||||
|
||||
# Create mock kv_b_proj using the same weights as reference implementation
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
@@ -486,7 +485,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
randomize_blocks=True)
|
||||
|
||||
# 4. Run vLLM backends and compare
|
||||
for backend_name in BACKENDS_TO_TEST:
|
||||
for i, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||
backend_output = run_attention_backend(
|
||||
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
|
||||
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
|
||||
@@ -494,12 +493,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
mock_kv_b_proj)
|
||||
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == sdpa_output.shape, (
|
||||
assert backend_output.shape == sdpa_outputs[i].shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {sdpa_output.shape}")
|
||||
assert backend_output.dtype == sdpa_output.dtype, (
|
||||
f"SDPA shape {sdpa_outputs[i].shape}")
|
||||
assert backend_output.dtype == sdpa_outputs[i].dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {sdpa_output.dtype}")
|
||||
f"SDPA dtype {sdpa_outputs[i].dtype}")
|
||||
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
f"[{backend_name}] produced non-finite values")
|
||||
@@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
rtol = 1e-2
|
||||
atol = 5e-1
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
|
||||
max_diff = torch.max(torch.abs(backend_output -
|
||||
sdpa_outputs[i])).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - sdpa_output) /
|
||||
torch.abs(sdpa_output)).item()
|
||||
torch.abs(backend_output - sdpa_outputs[i]) /
|
||||
torch.abs(sdpa_outputs[i])).item()
|
||||
all_close = torch.allclose(backend_output,
|
||||
sdpa_output,
|
||||
sdpa_outputs[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user