[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)

This commit is contained in:
Cyrus Leung
2024-10-07 14:10:35 +08:00
committed by GitHub
parent 18b296fdb2
commit 8c6de96ea1
10 changed files with 342 additions and 37 deletions

View File

@@ -12,10 +12,12 @@ from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
_GENERATION_MODELS = {
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
@@ -74,10 +76,9 @@ _GENERATION_MODELS = {
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# NOTE: The below models are for speculative decoding only
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
_EMBEDDING_MODELS = {
@@ -114,16 +115,18 @@ _MULTIMODAL_MODELS = {
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
_SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
_MODELS = {
**_GENERATION_MODELS,
**_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
**_SPECULATIVE_DECODING_MODELS,
}
# Architecture -> type or (module, class).
@@ -317,6 +320,19 @@ class ModelRegistry:
return result.returncode == 0
@staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_txt_gen = partial(ModelRegistry._check_stateless,
is_text_generation_model,
default=False)
return any(is_txt_gen(arch) for arch in architectures)
@staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
@@ -324,7 +340,11 @@ class ModelRegistry:
if not architectures:
logger.warning("No model architectures are specified")
return any(arch in _EMBEDDING_MODELS for arch in architectures)
is_emb = partial(ModelRegistry._check_stateless,
is_embedding_model,
default=False)
return any(is_emb(arch) for arch in architectures)
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: