[Bugfix] Fix DCP + FA3 crash due to missing num_splits in _forward_with_dcp (#35082)

Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
haosdent
2026-02-27 22:27:06 +08:00
committed by GitHub
parent fbe3f0120a
commit 6d4f9d3ad5

View File

@@ -847,6 +847,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
@@ -876,6 +877,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape