[Attention] Move MLA forward from backend to layer (#33284)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-30 22:30:00 -05:00
committed by GitHub
parent 010ec0c30e
commit aaa901ad55
13 changed files with 753 additions and 535 deletions

View File

@@ -12,7 +12,7 @@ import torch
 from tests.v1.attention.test_mla_backends import (
 BATCH_SPECS,
 BatchSpec,
-MockAttentionLayer,
+MockSparseMLAAttentionLayer,
 create_and_prepopulate_kv_cache,
 )
 from tests.v1.attention.utils import (
@@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness(
 impl.process_weights_after_loading(dtype)
-layer = MockAttentionLayer(device)
+# Create mock sparse MLA layer with weight matrices
+mock_layer = MockSparseMLAAttentionLayer(
+impl=impl,
+num_heads=num_heads,
+qk_nope_head_dim=qk_nope_head_dim,
+qk_rope_head_dim=qk_rope_head_dim,
+v_head_dim=v_head_dim,
+kv_lora_rank=kv_lora_rank,
+device=device,
+W_UK=W_UK,
+W_UV=W_UV,
+)
 out_buffer = torch.empty(
 metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
 )
 with torch.inference_mode():
-backend_output = impl.forward(
-layer,
+backend_output = mock_layer.forward_impl(
 query_vllm,
 kv_c_vllm,
 k_pe_vllm,
 kv_cache,
 metadata,
-output=out_buffer,
+out_buffer,
 )
 assert backend_output.shape == sdpa_reference.shape