[Attention] Move MLA forward from backend to layer (#33284)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import torch
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS,
|
||||
BatchSpec,
|
||||
MockAttentionLayer,
|
||||
MockSparseMLAAttentionLayer,
|
||||
create_and_prepopulate_kv_cache,
|
||||
)
|
||||
from tests.v1.attention.utils import (
|
||||
@@ -408,20 +408,31 @@ def test_sparse_backend_decode_correctness(
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
# Create mock sparse MLA layer with weight matrices
|
||||
mock_layer = MockSparseMLAAttentionLayer(
|
||||
impl=impl,
|
||||
num_heads=num_heads,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
device=device,
|
||||
W_UK=W_UK,
|
||||
W_UV=W_UV,
|
||||
)
|
||||
|
||||
out_buffer = torch.empty(
|
||||
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
backend_output = impl.forward(
|
||||
layer,
|
||||
backend_output = mock_layer.forward_impl(
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer,
|
||||
out_buffer,
|
||||
)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
|
||||
Reference in New Issue
Block a user