[Attention] Use FA4 for MLA prefill (#34732)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-03-12 12:10:17 -04:00
committed by GitHub
parent 85199f9681
commit f444c05c32
9 changed files with 413 additions and 78 deletions

View File

@@ -77,6 +77,7 @@ class MockKVBProj:
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim
self.weight = torch.empty(0, dtype=torch.bfloat16)
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""
@@ -213,6 +214,7 @@ class BenchmarkConfig:
use_cuda_graphs: bool = False
# MLA-specific
prefill_backend: str | None = None
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None