From 6f2f59f2b333151aac19f8ca7bf71d83c1a7c068 Mon Sep 17 00:00:00 2001 From: Zhengkai Zhang <33679250+ZhengkaiZ@users.noreply.github.com> Date: Tue, 10 Feb 2026 14:52:43 -0800 Subject: [PATCH] [Misc][Spec Decode] support different load config for draft model (#34022) Signed-off-by: zzhengkai Co-authored-by: zzhengkai --- vllm/config/speculative.py | 5 +++++ vllm/model_executor/model_loader/__init__.py | 3 ++- vllm/v1/spec_decode/eagle.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8117349d8..47e4a7bbb 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator from typing_extensions import Self +from vllm.config import LoadConfig from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.utils import config @@ -160,6 +161,10 @@ class SpeculativeConfig: tokens with estimated probability (based on frequency counts) greater than or equal to this value.""" + draft_load_config: LoadConfig | None = None + """Load config for the draft model. If not specified, will use the load + config from the target model.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index e1d8d2ead..ff95d5b94 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -128,8 +128,9 @@ def get_model( vllm_config: VllmConfig, model_config: ModelConfig | None = None, prefix: str = "", + load_config: LoadConfig | None = None, ) -> nn.Module: - loader = get_model_loader(vllm_config.load_config) + loader = get_model_loader(load_config or vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config return loader.load_model( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d29ee00fa..b5532d652 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer: model = get_model( vllm_config=self.vllm_config, model_config=self.speculative_config.draft_model_config, + load_config=self.speculative_config.draft_load_config, ) return model