[Core] Support model loader plugins (#21067)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
37
tests/model_executor/model_loader/test_registry.py
Normal file
37
tests/model_executor/model_loader/test_registry.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.model_executor.model_loader import (get_model_loader,
|
||||
register_model_loader)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
|
||||
|
||||
@register_model_loader("custom_load_format")
|
||||
class CustomModelLoader(BaseModelLoader):
|
||||
|
||||
def __init__(self, load_config: LoadConfig) -> None:
|
||||
super().__init__(load_config)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_register_model_loader():
|
||||
load_config = LoadConfig(load_format="custom_load_format")
|
||||
assert isinstance(get_model_loader(load_config), CustomModelLoader)
|
||||
|
||||
|
||||
def test_invalid_model_loader():
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@register_model_loader("invalid_load_format")
|
||||
class InValidModelLoader:
|
||||
pass
|
||||
Reference in New Issue
Block a user