[Attention] FlashAttn MLA (#14258)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder(
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
-            split_decodes_and_prefills(common_attn_metadata,
-                                       decode_threshold=1))
+            split_decodes_and_prefills(
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold))
 
         attn_metadata = LinearAttentionMetadata(
             num_prefills=num_prefills,
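For context, the hunk replaces the hardcoded decode_threshold=1 with the builder's self.reorder_batch_threshold. Below is a minimal, illustrative sketch (not vLLM's actual implementation) of what a threshold-based decode/prefill split computes, assuming the batch has already been reordered so decode requests come first; the function name split_decodes_and_prefills_sketch and the query_lens input are assumptions made for this example.

# Illustrative sketch only: split a reordered batch into "decode" requests
# (at most `decode_threshold` query tokens) and "prefill" requests (longer).
from typing import List, Tuple


def split_decodes_and_prefills_sketch(
        query_lens: List[int],
        decode_threshold: int = 1) -> Tuple[int, int, int, int]:
    """Assumes decodes precede prefills in query_lens (reordered batch)."""
    num_decodes = 0
    num_decode_tokens = 0
    for q in query_lens:
        if q > decode_threshold:
            break  # first prefill found; everything after is a prefill
        num_decodes += 1
        num_decode_tokens += q
    num_prefills = len(query_lens) - num_decodes
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# With decode_threshold=1 only single-token queries take the decode path;
# a larger threshold lets short multi-token queries count as decodes too.
print(split_decodes_and_prefills_sketch([1, 1, 2, 8], decode_threshold=1))
# -> (2, 2, 2, 10)
print(split_decodes_and_prefills_sketch([1, 1, 2, 8], decode_threshold=2))
# -> (3, 1, 4, 8)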