[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: hongchao <hongchao@msh.team> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
assert not self.use_mla, "MLA is not supported for sliding window"
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
|
||||
"DCP not support sliding window."
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
Reference in New Issue
Block a user