feat: spec decode with draft models (#24322)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
Tomas Ruiz
2026-01-19 15:05:46 -06:00
committed by GitHub
parent 73f2a81c75
commit 4a5299c93f
21 changed files with 897 additions and 115 deletions

View File

@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
 def get_model(
-    *, vllm_config: VllmConfig, model_config: ModelConfig | None = None
+    *,
+    vllm_config: VllmConfig,
+    model_config: ModelConfig | None = None,
+    prefix: str = "",
 ) -> nn.Module:
     loader = get_model_loader(vllm_config.load_config)
     if model_config is None:
         model_config = vllm_config.model_config
-    return loader.load_model(vllm_config=vllm_config, model_config=model_config)
+    return loader.load_model(
+        vllm_config=vllm_config, model_config=model_config, prefix=prefix
+    )
 __all__ = [

View File

@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
         raise NotImplementedError
     def load_model(
-        self, vllm_config: VllmConfig, model_config: ModelConfig
+        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
     ) -> nn.Module:
         """Load a model with the given configurations."""
         device_config = vllm_config.device_config
device_config = vllm_config.device_config
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
                 model = initialize_model(
-                    vllm_config=vllm_config, model_config=model_config
+                    vllm_config=vllm_config, model_config=model_config, prefix=prefix
                 )
         log_model_inspection(model)

View File

@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
         )
     def load_model(
-        self, vllm_config: VllmConfig, model_config: ModelConfig
+        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
     ) -> nn.Module:
         device_config = vllm_config.device_config
         local_model_path = self._prepare_weights(model_config)
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config, prefix=prefix)
         self.load_weights(model, model_config)
         process_weights_after_loading(model, model_config, target_device)

View File

@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
     def _load_model_serialized_cpu(
         self,
         vllm_config: VllmConfig,
+        prefix: str = "",
     ) -> nn.Module:
         """Load a serialized model with tensorizer to the CPU.
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
         model_config = vllm_config.model_config
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = initialize_model(vllm_config=vllm_config)
+                model = initialize_model(vllm_config=vllm_config, prefix=prefix)
         model.load_weights(self._get_weights_iterator())
         return model.eval()
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
         model.load_weights(self._get_weights_iterator())
     def load_model(
-        self, vllm_config: VllmConfig, model_config: ModelConfig
+        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
     ) -> nn.Module:
         parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
             )
             self.load_weights(model, model_config)
             return model
-        return self._load_model_serialized_cpu(vllm_config=vllm_config)
+        return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
     @staticmethod
     def save_model(