[Misc] Add quantization config support for speculative model. (#7343)

This commit is contained in:
shangmingc
2024-08-16 10:34:28 +08:00
committed by GitHub
parent 9c8e2d1161
commit b67ae00cdb
3 changed files with 71 additions and 4 deletions

View File

@@ -961,6 +961,7 @@ class SpeculativeConfig:
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
@@ -989,6 +990,9 @@ class SpeculativeConfig:
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_model_quantization (Optional[str]): Quantization method
that was used to quantize the speculative model weights. If
None, we assume the model weights are not quantized.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
@@ -1056,11 +1060,11 @@ class SpeculativeConfig:
"Speculative decoding requires usage of the V2 "
"block manager. Enable it with --use-v2-block-manager.")
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = None
draft_quantization = speculative_model_quantization
if speculative_model == "[ngram]":
if ngram_prompt_lookup_min is None:
@@ -1217,7 +1221,7 @@ class SpeculativeConfig:
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be"
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")
draft_parallel_config = ParallelConfig(