[DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttention (#24864)

Signed-off-by: yuanyongjie.yyj <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
Signed-off-by: Jaya Yuan <yuanyongjie.yyj@antgroup.com>
This commit is contained in:
Jaya Yuan
2025-10-14 21:07:50 +08:00
committed by GitHub
parent fdd32750f0
commit ea97940d6c
7 changed files with 209 additions and 33 deletions

View File

@@ -1202,6 +1202,23 @@ class ModelConfig:
"Supported models implement the `SupportsPP` interface."
)
decode_context_parallel_size = parallel_config.decode_context_parallel_size
if decode_context_parallel_size > 1 and not self.use_mla:
    # Decode context parallel (DCP) for GQA/MQA shards decode work across
    # ranks that hold replicas of the same KV head, so each KV head must
    # be replicated across at least `decode_context_parallel_size` TP
    # ranks. That is only possible when TP size exceeds the KV head count.
    # NOTE(review): validation uses `raise ValueError` rather than
    # `assert` so the checks survive `python -O`, matching the raised
    # errors used by the surrounding config validation.
    total_num_kv_heads = self.get_total_num_kv_heads()
    if tensor_parallel_size <= total_num_kv_heads:
        raise ValueError(
            f"tensor parallel size {tensor_parallel_size} must be greater "
            f"than total num kv heads {total_num_kv_heads} when enabling "
            f"decode context parallel for GQA/MQA"
        )
    # Each KV head is replicated (tp // kv_heads) times; DCP can split
    # decode at most that many ways.
    max_dcp_size = tensor_parallel_size // total_num_kv_heads
    if decode_context_parallel_size > max_dcp_size:
        raise ValueError(
            f"decode context parallel size must be less than or equal to "
            f"(tensor parallel size {tensor_parallel_size} // total "
            f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
            f"but got {decode_context_parallel_size}"
        )
def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)