[Misc][Spec Decode] support different load config for draft model (#34022)
Signed-off-by: zzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
Co-authored-by: zzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import LoadConfig
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
@@ -160,6 +161,10 @@ class SpeculativeConfig:
|
||||
tokens with estimated probability (based on frequency counts) greater than
|
||||
or equal to this value."""
|
||||
|
||||
draft_load_config: LoadConfig | None = None
|
||||
"""Load config for the draft model. If not specified, will use the load
|
||||
config from the target model."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@@ -128,8 +128,9 @@ def get_model(
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig | None = None,
|
||||
prefix: str = "",
|
||||
load_config: LoadConfig | None = None,
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
loader = get_model_loader(load_config or vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(
|
||||
|
||||
@@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer:
|
||||
model = get_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
load_config=self.speculative_config.draft_load_config,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user