diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8117349d8..47e4a7bbb 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator from typing_extensions import Self +from vllm.config import LoadConfig from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.utils import config @@ -160,6 +161,10 @@ class SpeculativeConfig: tokens with estimated probability (based on frequency counts) greater than or equal to this value.""" + draft_load_config: LoadConfig | None = None + """Load config for the draft model. If not specified, will use the load + config from the target model.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index e1d8d2ead..ff95d5b94 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -128,8 +128,9 @@ def get_model( vllm_config: VllmConfig, model_config: ModelConfig | None = None, prefix: str = "", + load_config: LoadConfig | None = None, ) -> nn.Module: - loader = get_model_loader(vllm_config.load_config) + loader = get_model_loader(load_config or vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config return loader.load_model( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d29ee00fa..b5532d652 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer: model = get_model( vllm_config=self.vllm_config, model_config=self.speculative_config.draft_model_config, + load_config=self.speculative_config.draft_load_config, ) return model