[Bugfix] Fix DCP + FA3 crash due to missing num_splits in _forward_with_dcp (#35082)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -847,6 +847,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
)
|
||||
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
|
||||
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
|
||||
@@ -876,6 +877,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
)
|
||||
assert context_attn_out_cor.shape == query_attn_out.shape
|
||||
assert context_lse_cor.shape == query_lse.shape
|
||||
|
||||
Reference in New Issue
Block a user