[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)

(cherry picked from commit db56a59970)
This commit is contained in:
Lucas Wilkinson
2025-11-14 07:19:22 -05:00
committed by Kevin H. Luu
parent f7adf64aac
commit c505dd6b61
2 changed files with 5 additions and 2 deletions

View File

@@ -170,6 +170,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
+max_num_splits=0, # no max
fa_version=fa_version,
)

View File

@@ -675,6 +675,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
+max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@@ -921,6 +922,7 @@ def cascade_attention(
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
+max_num_splits: int,
fa_version: int,
prefix_scheduler_metadata: torch.Tensor | None = None,
suffix_scheduler_metadata: torch.Tensor | None = None,
@@ -965,7 +967,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
-num_splits=1 if vllm_is_batch_invariant() else 0,
+num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -990,7 +992,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-num_splits=1 if vllm_is_batch_invariant() else 0,
+num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.