[Speculative Decoding] Support draft model on different tensor-parallel size than target model (#5414)

This commit is contained in:
Woo-Yeon Lee
2024-06-25 18:56:06 +09:00
committed by GitHub
parent f23871e9ee
commit 2ce5d6688b
11 changed files with 388 additions and 59 deletions

View File

@@ -94,6 +94,7 @@ class EngineArgs:
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
@@ -537,6 +538,13 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
@@ -686,6 +694,8 @@ class EngineArgs:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,