Disable spec-decode + chunked-prefill for draft models with tensor parallelism > 1 (#10136)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
sroy745
2024-11-08 07:56:18 -08:00
committed by GitHub
parent 0535e5fe6c
commit f6778620a9
2 changed files with 83 additions and 8 deletions

View File

@@ -1388,6 +1388,23 @@ class SpeculativeConfig:
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
speculative_draft_tensor_parallel_size,
draft_hf_config
)
if (enable_chunked_prefill and \
speculative_draft_tensor_parallel_size != 1):
# TODO - Investigate why the error reported in
# https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
# is happening and re-enable it.
raise ValueError(
"Chunked prefill and speculative decoding can be enabled "
"simultaneously only for draft models with tensor "
"parallel size 1.")
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
@@ -1466,15 +1483,16 @@ class SpeculativeConfig:
)
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
def _verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int:
"""
Verifies and adjusts the tensor parallel size for a draft model
specified using speculative_draft_tensor_parallel_size.
"""
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
@@ -1490,7 +1508,18 @@ class SpeculativeConfig:
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1 or target model tensor_parallel_size")
return speculative_draft_tensor_parallel_size
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,