[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
This commit is contained in:
@@ -12,10 +12,12 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import supports_multimodal, supports_pp
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GENERATION_MODELS = {
|
||||
_TEXT_GENERATION_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
||||
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||
@@ -74,10 +76,9 @@ _GENERATION_MODELS = {
|
||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||
# NOTE: The below models are for speculative decoding only
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
# [Encoder-decoder]
|
||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||
}
|
||||
|
||||
_EMBEDDING_MODELS = {
|
||||
@@ -114,16 +115,18 @@ _MULTIMODAL_MODELS = {
|
||||
"MllamaForConditionalGeneration": ("mllama",
|
||||
"MllamaForConditionalGeneration"),
|
||||
}
|
||||
_CONDITIONAL_GENERATION_MODELS = {
|
||||
"BartModel": ("bart", "BartForConditionalGeneration"),
|
||||
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
}
|
||||
|
||||
_MODELS = {
|
||||
**_GENERATION_MODELS,
|
||||
**_TEXT_GENERATION_MODELS,
|
||||
**_EMBEDDING_MODELS,
|
||||
**_MULTIMODAL_MODELS,
|
||||
**_CONDITIONAL_GENERATION_MODELS,
|
||||
**_SPECULATIVE_DECODING_MODELS,
|
||||
}
|
||||
|
||||
# Architecture -> type or (module, class).
|
||||
@@ -317,6 +320,19 @@ class ModelRegistry:
|
||||
|
||||
return result.returncode == 0
|
||||
|
||||
@staticmethod
|
||||
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
|
||||
if isinstance(architectures, str):
|
||||
architectures = [architectures]
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
is_txt_gen = partial(ModelRegistry._check_stateless,
|
||||
is_text_generation_model,
|
||||
default=False)
|
||||
|
||||
return any(is_txt_gen(arch) for arch in architectures)
|
||||
|
||||
@staticmethod
|
||||
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
|
||||
if isinstance(architectures, str):
|
||||
@@ -324,7 +340,11 @@ class ModelRegistry:
|
||||
if not architectures:
|
||||
logger.warning("No model architectures are specified")
|
||||
|
||||
return any(arch in _EMBEDDING_MODELS for arch in architectures)
|
||||
is_emb = partial(ModelRegistry._check_stateless,
|
||||
is_embedding_model,
|
||||
default=False)
|
||||
|
||||
return any(is_emb(arch) for arch in architectures)
|
||||
|
||||
@staticmethod
|
||||
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user