diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d903bd89c..940dc7515 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -847,6 +847,7 @@ class FlashAttentionImpl(AttentionImpl): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + num_splits=attn_metadata.max_num_splits, ) # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( @@ -876,6 +877,7 @@ class FlashAttentionImpl(AttentionImpl): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + num_splits=attn_metadata.max_num_splits, ) assert context_attn_out_cor.shape == query_attn_out.shape assert context_lse_cor.shape == query_lse.shape