feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
|
||||
|
||||
def get_model(
|
||||
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||
return loader.load_model(
|
||||
vllm_config=vllm_config, model_config=model_config, prefix=prefix
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
"""Load a model with the given configurations."""
|
||||
device_config = vllm_config.device_config
|
||||
@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(
|
||||
vllm_config=vllm_config, model_config=model_config
|
||||
vllm_config=vllm_config, model_config=model_config, prefix=prefix
|
||||
)
|
||||
|
||||
log_model_inspection(model)
|
||||
|
||||
@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
local_model_path = self._prepare_weights(model_config)
|
||||
@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
|
||||
self.load_weights(model, model_config)
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
def _load_model_serialized_cpu(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
|
||||
@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model_config = vllm_config.model_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
|
||||
def load_model(
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
|
||||
) -> nn.Module:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self._verify_config(model_config, parallel_config)
|
||||
@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
)
|
||||
self.load_weights(model, model_config)
|
||||
return model
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config)
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
|
||||
Reference in New Issue
Block a user