[DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttention (#24864)
Signed-off-by: yuanyongjie.yyj <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
Signed-off-by: Jaya Yuan <yuanyongjie.yyj@antgroup.com>
This commit is contained in:
@@ -1202,6 +1202,23 @@ class ModelConfig:
|
||||
"Supported models implement the `SupportsPP` interface."
|
||||
)
|
||||
|
||||
decode_context_parallel_size = parallel_config.decode_context_parallel_size
|
||||
if decode_context_parallel_size > 1 and not self.use_mla:
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
assert tensor_parallel_size > total_num_kv_heads, (
|
||||
f"tensor parallel size {tensor_parallel_size} must be greater "
|
||||
f"than total num kv heads {total_num_kv_heads} when enable "
|
||||
f"decode context parallel for GQA/MQA"
|
||||
)
|
||||
|
||||
max_dcp_size = tensor_parallel_size // total_num_kv_heads
|
||||
assert decode_context_parallel_size <= max_dcp_size, (
|
||||
f"decode context parallel size must less than or equal to "
|
||||
f"(tensor parallel size {tensor_parallel_size} // total "
|
||||
f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
|
||||
f"but got {decode_context_parallel_size}"
|
||||
)
|
||||
|
||||
def get_sliding_window(self) -> int | None:
|
||||
"""Get the sliding window size from the HF text config if present."""
|
||||
return getattr(self.hf_text_config, "sliding_window", None)
|
||||
|
||||
Reference in New Issue
Block a user