[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2025-05-23 03:05:44 +01:00
committed by GitHub
parent 04eb88dc80
commit c6b636f9fb
16 changed files with 59 additions and 135 deletions

View File

@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from torch import nn
from vllm.config import LoadConfig, LoadFormat, VllmConfig
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
@@ -47,9 +49,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
return DefaultModelLoader(load_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
def get_model(*,
vllm_config: VllmConfig,
model_config: Optional[ModelConfig] = None) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config,
model_config=model_config)
__all__ = [

View File

@@ -18,6 +18,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError
@abstractmethod
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, *, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError

View File

@@ -569,10 +569,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):

View File

@@ -264,13 +264,14 @@ class DefaultModelLoader(BaseModelLoader):
fall_back_to_pt=True,
allow_patterns_overrides=None)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(

View File

@@ -22,9 +22,9 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:

View File

@@ -92,9 +92,9 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights

View File

@@ -100,11 +100,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Perform streaming of the model to destination"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:

View File

@@ -100,9 +100,9 @@ class ShardedStateLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from vllm.distributed import get_tensor_model_parallel_rank

View File

@@ -93,8 +93,8 @@ class TensorizerLoader(BaseModelLoader):
with self.tensorizer_config.open_stream():
pass
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config = vllm_config.model_config
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)

View File

@@ -42,9 +42,11 @@ def initialize_model(
*,
prefix: str = "",
model_class: Optional[type[nn.Module]] = None,
model_config: Optional[ModelConfig] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
if model_config is None:
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)

View File

@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=start_layer_id)
start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,

View File

@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
start_layer_id=start_layer_id,
prefix="model")
prefix="model",
start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.lm_head = ParallelLMHead(
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
scale=logit_scale)
self.draft_id_to_target_id = nn.Parameter(
torch.zeros((self.config.draft_vocab_size),
dtype=torch.long).type(torch.LongTensor),
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
requires_grad=False,
)

View File

@@ -51,10 +51,7 @@ class Medusa(nn.Module):
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
if hasattr(vllm_config, 'draft_model_config'):
config = vllm_config.draft_model_config.hf_config
else:
config = vllm_config.model_config.hf_config
config = vllm_config.speculative_config.draft_model_config.hf_config
super().__init__()
self.config = config
self.blocks = nn.ModuleList([