[Misc][Spec Decode] support different load config for draft model (#34022)

Signed-off-by: zzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
Co-authored-by: zzhengkai <zzhengkai@devgpu049.ldc1.facebook.com>
This commit is contained in:
Zhengkai Zhang
2026-02-10 14:52:43 -08:00
committed by GitHub
parent bb2fc8b5e7
commit 6f2f59f2b3
3 changed files with 8 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
@@ -160,6 +161,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -128,8 +128,9 @@ def get_model(
vllm_config: VllmConfig,
model_config: ModelConfig | None = None,
prefix: str = "",
load_config: LoadConfig | None = None,
) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
loader = get_model_loader(load_config or vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(

View File

@@ -1286,6 +1286,7 @@ class SpecDecodeBaseProposer:
model = get_model(
vllm_config=self.vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)
return model