[Attention] FlashAttn MLA (#14258)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-09-04 05:47:59 -04:00
committed by GitHub
parent 2c301ee2eb
commit 402759d472
22 changed files with 480 additions and 200 deletions

View File

@@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,