2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2024-11-15 08:55:54 +08:00
|
|
|
"""
|
|
|
|
|
Whenever you add an architecture to this page, please also update
|
|
|
|
|
`tests/models/registry.py` with example HuggingFace models for it.
|
|
|
|
|
"""
|
2025-10-05 15:06:22 +01:00
|
|
|
|
2024-10-04 18:01:37 +08:00
|
|
|
import importlib
|
2025-09-20 07:51:13 -04:00
|
|
|
import json
|
2024-10-29 13:08:20 +08:00
|
|
|
import os
|
2024-10-09 21:58:27 -07:00
|
|
|
import pickle
|
2024-10-04 18:01:37 +08:00
|
|
|
import subprocess
|
|
|
|
|
import sys
|
2024-10-09 21:58:27 -07:00
|
|
|
import tempfile
|
2024-10-11 19:08:11 +08:00
|
|
|
from abc import ABC, abstractmethod
|
2025-05-15 06:06:50 +01:00
|
|
|
from collections.abc import Callable, Set
|
2025-09-20 07:51:13 -04:00
|
|
|
from dataclasses import asdict, dataclass, field
|
2024-10-11 19:08:11 +08:00
|
|
|
from functools import lru_cache
|
2025-09-20 07:51:13 -04:00
|
|
|
from pathlib import Path
|
2025-11-28 14:05:48 +08:00
|
|
|
from typing import TYPE_CHECKING, Any, TypeVar
|
2024-10-04 18:01:37 +08:00
|
|
|
|
|
|
|
|
import torch.nn as nn
|
2025-07-28 10:42:40 +08:00
|
|
|
import transformers
|
2024-10-04 18:01:37 +08:00
|
|
|
|
2025-09-20 07:51:13 -04:00
|
|
|
from vllm import envs
|
2025-09-19 17:22:33 +01:00
|
|
|
from vllm.config import (
|
|
|
|
|
ModelConfig,
|
|
|
|
|
iter_architecture_defaults,
|
2025-07-28 10:42:40 +08:00
|
|
|
try_match_architecture_defaults,
|
|
|
|
|
)
|
2024-10-04 18:01:37 +08:00
|
|
|
from vllm.logger import init_logger
|
2025-09-20 07:51:13 -04:00
|
|
|
from vllm.logging_utils import logtime
|
2025-07-28 10:42:40 +08:00
|
|
|
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
2025-11-25 18:50:22 -06:00
|
|
|
from vllm.utils.hashing import safe_hash
|
2025-10-05 15:06:22 +01:00
|
|
|
|
2025-11-28 14:05:48 +08:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from vllm.config.model import AttnTypeStr
|
|
|
|
|
from vllm.config.pooler import PoolingTypeStr
|
|
|
|
|
else:
|
|
|
|
|
AttnTypeStr = Any
|
|
|
|
|
PoolingTypeStr = Any
|
|
|
|
|
|
|
|
|
|
|
2025-08-27 21:24:09 +08:00
|
|
|
from .interfaces import (
|
|
|
|
|
has_inner_state,
|
|
|
|
|
has_noops,
|
|
|
|
|
is_attention_free,
|
|
|
|
|
is_hybrid,
|
2025-12-23 14:31:55 +01:00
|
|
|
requires_raw_input_tokens,
|
2025-08-27 21:24:09 +08:00
|
|
|
supports_cross_encoding,
|
2025-10-27 15:05:20 +02:00
|
|
|
supports_mamba_prefix_caching,
|
2025-08-27 21:41:22 +08:00
|
|
|
supports_multimodal,
|
|
|
|
|
supports_multimodal_encoder_tp_data,
|
2025-08-28 01:01:50 +08:00
|
|
|
supports_multimodal_raw_input_only,
|
|
|
|
|
supports_pp,
|
2025-08-27 21:41:22 +08:00
|
|
|
supports_transcription,
|
|
|
|
|
)
|
2025-08-27 21:24:09 +08:00
|
|
|
from .interfaces_base import (
|
2025-11-28 14:05:48 +08:00
|
|
|
get_attn_type,
|
2025-08-27 21:24:09 +08:00
|
|
|
get_default_pooling_type,
|
|
|
|
|
is_pooling_model,
|
|
|
|
|
is_text_generation_model,
|
|
|
|
|
)
|
2024-10-04 18:01:37 +08:00
|
|
|
|
|
|
|
|
# Module-level logger, scoped to this module's dotted path.
logger = init_logger(__name__)
|
|
|
|
|
|
2024-10-07 14:10:35 +08:00
|
|
|
_TEXT_GENERATION_MODELS = {
|
|
|
|
|
# [Decoder-only]
|
2025-11-17 15:11:20 -08:00
|
|
|
"AfmoeForCausalLM": ("afmoe", "AfmoeForCausalLM"),
|
2025-08-29 14:29:18 +02:00
|
|
|
"ApertusForCausalLM": ("apertus", "ApertusForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"AquilaModel": ("llama", "LlamaForCausalLM"),
|
|
|
|
|
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
|
2025-07-22 13:27:43 +05:30
|
|
|
"ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
2024-10-23 09:14:44 -07:00
|
|
|
# baichuan-7b, upper case 'C' in the class name
|
|
|
|
|
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
|
|
|
|
|
# baichuan-13b, lower case 'c' in the class name
|
|
|
|
|
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
|
2025-07-14 22:10:32 +08:00
|
|
|
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
|
2025-09-15 20:09:30 +08:00
|
|
|
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
|
2025-02-07 07:22:42 +08:00
|
|
|
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
2025-02-13 22:19:15 +08:00
|
|
|
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
2025-04-07 23:15:58 +08:00
|
|
|
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
2024-12-16 11:56:19 +02:00
|
|
|
"Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
|
2025-09-25 09:37:03 +02:00
|
|
|
"CwmForCausalLM": ("llama", "LlamaForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
|
2025-03-31 15:35:14 +03:00
|
|
|
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
2025-11-08 13:01:27 +08:00
|
|
|
"DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
2025-02-06 02:16:20 -05:00
|
|
|
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
|
2025-09-30 05:14:41 -04:00
|
|
|
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
|
2025-06-30 10:34:36 +08:00
|
|
|
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
|
2025-07-28 14:22:32 +02:00
|
|
|
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
|
2025-07-02 18:37:01 +08:00
|
|
|
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
|
2025-07-19 15:25:44 +09:00
|
|
|
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
|
2025-01-19 19:40:40 +01:00
|
|
|
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
|
2025-10-14 21:26:11 +08:00
|
|
|
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
|
|
|
|
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
|
|
|
|
"FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"),
|
2025-10-10 09:43:15 -07:00
|
|
|
"FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
|
|
|
|
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
2025-03-12 08:36:33 -07:00
|
|
|
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
2025-08-09 18:56:25 +02:00
|
|
|
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
|
2025-09-11 15:32:09 +08:00
|
|
|
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
|
2024-11-28 22:53:31 +08:00
|
|
|
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
2025-04-10 09:19:42 +08:00
|
|
|
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
2025-07-20 06:40:31 +08:00
|
|
|
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
|
2025-08-05 23:26:00 -07:00
|
|
|
"GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
|
|
|
|
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
|
|
|
|
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
|
|
|
|
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
|
|
|
|
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
|
|
|
|
|
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501
|
|
|
|
|
"GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501
|
2024-12-11 22:39:16 -08:00
|
|
|
"GritLM": ("gritlm", "GritLM"),
|
2026-01-08 13:59:48 +01:00
|
|
|
"Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
|
|
|
|
|
"Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
|
2025-07-23 18:54:08 +08:00
|
|
|
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
|
|
|
|
|
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
|
2025-07-25 22:05:42 +09:00
|
|
|
"HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
|
|
|
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
2024-10-23 00:01:46 +08:00
|
|
|
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
|
2025-01-15 19:35:17 +08:00
|
|
|
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
2026-01-08 22:42:57 +08:00
|
|
|
"IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"),
|
|
|
|
|
"IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
2025-12-18 21:16:58 +05:30
|
|
|
"Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
2025-10-30 21:02:27 +08:00
|
|
|
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
|
2025-08-21 01:35:07 -06:00
|
|
|
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
|
2025-10-08 01:03:05 +09:00
|
|
|
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
2025-10-06 00:29:18 -07:00
|
|
|
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
# For decapoda-research/llama-*
|
|
|
|
|
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
2025-09-25 12:53:40 +08:00
|
|
|
"LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"),
|
2024-10-11 11:40:06 -04:00
|
|
|
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
2025-02-17 07:17:50 -05:00
|
|
|
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
|
2024-11-15 08:55:54 +08:00
|
|
|
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
|
|
|
|
|
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
|
2025-10-14 21:26:11 +08:00
|
|
|
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
|
|
|
|
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
|
|
|
|
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
|
2025-10-27 00:59:11 +08:00
|
|
|
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
2025-12-02 11:29:00 +01:00
|
|
|
"MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
|
|
|
|
# transformers's mpt class has lower case
|
|
|
|
|
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
|
|
|
|
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
2025-05-13 07:25:33 +08:00
|
|
|
"MiMoForCausalLM": ("mimo", "MiMoForCausalLM"),
|
2025-12-20 01:17:03 +08:00
|
|
|
"MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
2025-06-05 14:29:28 -07:00
|
|
|
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
2024-11-25 14:26:40 -08:00
|
|
|
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
2025-09-12 20:26:21 -07:00
|
|
|
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
|
|
|
|
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
|
|
|
|
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
2025-10-30 22:34:41 +08:00
|
|
|
"OuroForCausalLM": ("ouro", "OuroForCausalLM"),
|
2025-11-05 00:17:20 +08:00
|
|
|
"PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
|
2025-12-31 00:11:38 +08:00
|
|
|
"PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
|
2025-11-05 00:17:20 +08:00
|
|
|
"PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
|
|
|
|
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
|
|
|
|
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
|
|
|
|
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
2025-04-16 11:31:30 +09:00
|
|
|
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
|
2025-11-20 20:00:19 +09:00
|
|
|
"Plamo3ForCausalLM": ("plamo3", "Plamo3ForCausalLM"),
|
2025-02-13 22:19:15 +08:00
|
|
|
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
|
|
|
|
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
2025-04-07 19:06:41 +08:00
|
|
|
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
|
|
|
|
|
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
2025-08-22 12:58:10 +08:00
|
|
|
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
|
2025-07-31 23:19:06 +08:00
|
|
|
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
|
|
|
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
|
|
|
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
|
|
|
|
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
2025-11-26 21:15:00 +08:00
|
|
|
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
2024-11-27 19:32:35 +08:00
|
|
|
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
2025-03-22 17:04:44 +08:00
|
|
|
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
|
2024-12-01 19:27:13 -08:00
|
|
|
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
2025-03-18 08:56:21 -07:00
|
|
|
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
2024-10-04 18:01:37 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Registry of pooling/embedding architectures.
# Same (module, class) value convention as _TEXT_GENERATION_MODELS.
_EMBEDDING_MODELS = {
    # [Text-only]
    "BertModel": ("bert", "BertEmbeddingModel"),
    "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
    "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
    "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
    "GritLM": ("gritlm", "GritLM"),
    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
    "GteNewModel": ("bert_with_rope", "GteNewModel"),
    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
    "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
    "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
    "LlamaModel": ("llama", "LlamaForCausalLM"),
    **{
        # Multiple models share the same architecture, so we include them all
        k: (mod, arch)
        for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
        if arch == "LlamaForCausalLM"
    },
    "MistralModel": ("llama", "LlamaForCausalLM"),
    "ModernBertModel": ("modernbert", "ModernBertModel"),
    "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
    "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
    "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
    "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": (
        "llava_next",
        "LlavaNextForConditionalGeneration",
    ),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
    # Technically Terratorch models work on images, both in
    # input and output. I am adding it here because it piggy-backs on embedding
    # models for the time being.
    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
    "Terratorch": ("terratorch", "Terratorch"),
}
|
|
|
|
|
|
2024-11-24 23:56:20 -03:00
|
|
|
_CROSS_ENCODER_MODELS = {
|
|
|
|
|
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
2025-09-18 23:22:01 +08:00
|
|
|
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"GteNewForSequenceClassification": (
|
|
|
|
|
"bert_with_rope",
|
|
|
|
|
"GteNewForSequenceClassification",
|
|
|
|
|
),
|
2025-12-23 12:19:16 +01:00
|
|
|
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
|
|
|
|
|
"LlamaBidirectionalForSequenceClassification": (
|
|
|
|
|
"llama",
|
|
|
|
|
"LlamaBidirectionalForSequenceClassification",
|
|
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"ModernBertForSequenceClassification": (
|
|
|
|
|
"modernbert",
|
|
|
|
|
"ModernBertForSequenceClassification",
|
|
|
|
|
),
|
2025-10-07 16:29:19 +02:00
|
|
|
"ModernBertForTokenClassification": (
|
|
|
|
|
"modernbert",
|
|
|
|
|
"ModernBertForTokenClassification",
|
|
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
|
|
|
|
|
"XLMRobertaForSequenceClassification": (
|
|
|
|
|
"roberta",
|
|
|
|
|
"RobertaForSequenceClassification",
|
|
|
|
|
),
|
2024-11-24 23:56:20 -03:00
|
|
|
}
|
|
|
|
|
|
2024-10-04 18:01:37 +08:00
|
|
|
_MULTIMODAL_MODELS = {
|
2024-10-07 19:55:12 +08:00
|
|
|
# [Decoder-only]
|
2024-11-26 02:10:55 +08:00
|
|
|
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
|
2025-12-14 05:14:55 -05:00
|
|
|
"AudioFlamingo3ForConditionalGeneration": (
|
|
|
|
|
"audioflamingo3",
|
|
|
|
|
"AudioFlamingo3ForConditionalGeneration",
|
|
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"AyaVisionForConditionalGeneration": (
|
|
|
|
|
"aya_vision",
|
|
|
|
|
"AyaVisionForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-12-15 14:58:23 +08:00
|
|
|
"BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
|
2025-10-20 10:31:26 +08:00
|
|
|
"BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
|
2024-10-07 19:55:12 +08:00
|
|
|
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"ChameleonForConditionalGeneration": (
|
|
|
|
|
"chameleon",
|
|
|
|
|
"ChameleonForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Cohere2VisionForConditionalGeneration": (
|
|
|
|
|
"cohere2_vision",
|
|
|
|
|
"Cohere2VisionForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-01-12 16:17:24 +08:00
|
|
|
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
|
2025-10-22 22:59:15 +08:00
|
|
|
"DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"),
|
2025-09-21 19:24:40 -07:00
|
|
|
"DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Ernie4_5_VLMoeForConditionalGeneration": (
|
|
|
|
|
"ernie45_vl",
|
|
|
|
|
"Ernie4_5_VLMoeForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2024-10-04 18:01:37 +08:00
|
|
|
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
2025-10-22 14:05:34 -03:00
|
|
|
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
|
2025-10-05 17:18:11 +01:00
|
|
|
"Gemma3nForConditionalGeneration": (
|
|
|
|
|
"gemma3n_mm",
|
|
|
|
|
"Gemma3nForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-12-31 10:12:24 -05:00
|
|
|
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
|
2025-02-13 22:19:15 +08:00
|
|
|
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
2025-07-01 20:48:26 +08:00
|
|
|
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
|
2025-08-13 08:13:17 +08:00
|
|
|
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
|
2025-10-05 17:18:11 +01:00
|
|
|
"GraniteSpeechForConditionalGeneration": (
|
|
|
|
|
"granite_speech",
|
|
|
|
|
"GraniteSpeechForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2024-11-03 18:15:36 -06:00
|
|
|
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
|
2025-11-25 11:28:51 +08:00
|
|
|
"HunYuanVLForConditionalGeneration": (
|
|
|
|
|
"hunyuan_vision",
|
|
|
|
|
"HunYuanVLForConditionalGeneration",
|
|
|
|
|
),
|
2024-10-04 18:01:37 +08:00
|
|
|
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
2025-09-26 02:10:29 +03:00
|
|
|
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
|
2025-11-24 11:27:55 +09:00
|
|
|
"OpenCUAForConditionalGeneration": (
|
|
|
|
|
"opencua",
|
|
|
|
|
"OpenCUAForConditionalGeneration",
|
|
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"InternS1ForConditionalGeneration": (
|
|
|
|
|
"interns1",
|
|
|
|
|
"InternS1ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"InternVLForConditionalGeneration": (
|
|
|
|
|
"interns1",
|
|
|
|
|
"InternS1ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Idefics3ForConditionalGeneration": (
|
|
|
|
|
"idefics3",
|
|
|
|
|
"Idefics3ForConditionalGeneration",
|
|
|
|
|
),
|
2025-12-25 21:49:11 -05:00
|
|
|
"IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"), # noqa: E501
|
2025-07-02 14:35:04 +08:00
|
|
|
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"KeyeVL1_5ForConditionalGeneration": (
|
|
|
|
|
"keye_vl1_5",
|
|
|
|
|
"KeyeVL1_5ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-08-21 12:08:52 +08:00
|
|
|
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
2025-04-15 05:41:48 +08:00
|
|
|
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
2025-10-17 07:05:24 +02:00
|
|
|
"LightOnOCRForConditionalGeneration": (
|
|
|
|
|
"lightonocr",
|
|
|
|
|
"LightOnOCRForConditionalGeneration",
|
|
|
|
|
),
|
2026-01-08 05:00:27 -08:00
|
|
|
"Lfm2VlForConditionalGeneration": ("lfm2_vl", "Lfm2VLForConditionalGeneration"),
|
2025-07-17 05:07:55 -05:00
|
|
|
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
|
2025-09-15 21:17:14 -07:00
|
|
|
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
|
2024-10-07 19:55:12 +08:00
|
|
|
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"LlavaNextForConditionalGeneration": (
|
|
|
|
|
"llava_next",
|
|
|
|
|
"LlavaNextForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"LlavaNextVideoForConditionalGeneration": (
|
|
|
|
|
"llava_next_video",
|
|
|
|
|
"LlavaNextVideoForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"LlavaOnevisionForConditionalGeneration": (
|
|
|
|
|
"llava_onevision",
|
|
|
|
|
"LlavaOnevisionForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2024-12-08 01:10:05 +08:00
|
|
|
"MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
|
2025-09-04 15:08:09 +08:00
|
|
|
"MiDashengLMModel": ("midashenglm", "MiDashengLMModel"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"MiniMaxVL01ForConditionalGeneration": (
|
|
|
|
|
"minimax_vl_01",
|
|
|
|
|
"MiniMaxVL01ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-01-29 17:24:59 +08:00
|
|
|
"MiniCPMO": ("minicpmo", "MiniCPMO"),
|
2024-10-04 18:01:37 +08:00
|
|
|
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Mistral3ForConditionalGeneration": (
|
|
|
|
|
"mistral3",
|
|
|
|
|
"Mistral3ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2024-10-14 07:56:24 -07:00
|
|
|
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
|
2024-10-07 19:55:12 +08:00
|
|
|
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
|
2025-05-12 08:56:30 +08:00
|
|
|
"Ovis": ("ovis", "Ovis"),
|
2025-08-19 21:12:59 +08:00
|
|
|
"Ovis2_5": ("ovis2_5", "Ovis2_5"),
|
2025-11-03 19:04:22 +08:00
|
|
|
"PaddleOCRVLForConditionalGeneration": (
|
|
|
|
|
"paddleocr_vl",
|
|
|
|
|
"PaddleOCRVLForConditionalGeneration",
|
|
|
|
|
),
|
2025-10-22 14:05:34 -03:00
|
|
|
"PaliGemmaForConditionalGeneration": (
|
|
|
|
|
"paligemma",
|
|
|
|
|
"PaliGemmaForConditionalGeneration",
|
|
|
|
|
),
|
2024-10-04 18:01:37 +08:00
|
|
|
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
2025-07-27 11:07:57 +08:00
|
|
|
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
|
2024-10-07 19:55:12 +08:00
|
|
|
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
|
2025-02-13 22:19:15 +08:00
|
|
|
"QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501
|
2024-10-07 19:55:12 +08:00
|
|
|
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
2025-10-05 17:18:11 +01:00
|
|
|
"Qwen2_5_VLForConditionalGeneration": (
|
|
|
|
|
"qwen2_5_vl",
|
|
|
|
|
"Qwen2_5_VLForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Qwen2AudioForConditionalGeneration": (
|
|
|
|
|
"qwen2_audio",
|
|
|
|
|
"Qwen2AudioForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Qwen2_5OmniModel": (
|
|
|
|
|
"qwen2_5_omni_thinker",
|
|
|
|
|
"Qwen2_5OmniThinkerForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"Qwen2_5OmniForConditionalGeneration": (
|
|
|
|
|
"qwen2_5_omni_thinker",
|
|
|
|
|
"Qwen2_5OmniThinkerForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-11 01:00:56 +08:00
|
|
|
"Qwen3OmniMoeForConditionalGeneration": (
|
|
|
|
|
"qwen3_omni_moe_thinker",
|
|
|
|
|
"Qwen3OmniMoeThinkerForConditionalGeneration",
|
|
|
|
|
),
|
2025-09-16 22:01:04 -07:00
|
|
|
"Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
|
2025-10-05 17:18:11 +01:00
|
|
|
"Qwen3VLMoeForConditionalGeneration": (
|
|
|
|
|
"qwen3_vl_moe",
|
|
|
|
|
"Qwen3VLMoeForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-09-15 21:17:14 -07:00
|
|
|
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
|
2025-07-31 23:19:06 +08:00
|
|
|
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
|
2025-06-03 13:13:13 +08:00
|
|
|
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
|
2025-10-05 17:18:11 +01:00
|
|
|
"Tarsier2ForConditionalGeneration": (
|
|
|
|
|
"qwen2_vl",
|
|
|
|
|
"Tarsier2ForConditionalGeneration",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-09-16 22:01:04 -07:00
|
|
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
2025-07-15 16:35:30 +02:00
|
|
|
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
2025-12-23 14:31:55 +01:00
|
|
|
"VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501
|
2024-10-07 19:55:12 +08:00
|
|
|
# [Encoder-decoder]
|
2026-01-05 23:00:14 +02:00
|
|
|
"NemotronParseForConditionalGeneration": (
|
|
|
|
|
"nemotron_parse",
|
|
|
|
|
"NemotronParseForConditionalGeneration",
|
|
|
|
|
),
|
2025-01-03 03:39:19 -05:00
|
|
|
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
2024-10-04 18:01:37 +08:00
|
|
|
}
|
2024-10-07 14:10:35 +08:00
|
|
|
|
|
|
|
|
_SPECULATIVE_DECODING_MODELS = {
|
2025-05-13 07:25:33 +08:00
|
|
|
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
2025-04-10 11:21:48 -07:00
|
|
|
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
2025-07-15 21:14:15 -07:00
|
|
|
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
2025-05-30 21:45:56 +08:00
|
|
|
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
2025-04-25 18:43:07 -04:00
|
|
|
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
2025-09-09 21:24:23 -07:00
|
|
|
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
2025-09-27 11:35:47 +08:00
|
|
|
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
2025-11-28 14:05:45 +08:00
|
|
|
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
2025-12-02 11:29:00 +01:00
|
|
|
"EagleMistralLarge3ForCausalLM": (
|
|
|
|
|
"mistral_large_3_eagle",
|
|
|
|
|
"EagleMistralLarge3ForCausalLM",
|
|
|
|
|
),
|
2025-08-20 04:01:31 -07:00
|
|
|
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
|
2025-02-19 01:06:23 -08:00
|
|
|
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
2025-08-20 20:41:55 +08:00
|
|
|
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
|
2025-09-25 12:53:40 +08:00
|
|
|
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
|
2025-07-20 06:40:31 +08:00
|
|
|
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
2024-10-07 14:10:35 +08:00
|
|
|
"MedusaModel": ("medusa", "Medusa"),
|
2025-11-05 00:17:20 +08:00
|
|
|
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
|
2025-09-11 15:32:09 +08:00
|
|
|
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
|
2025-07-18 21:47:50 -07:00
|
|
|
# Temporarily disabled.
|
|
|
|
|
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
|
|
|
|
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
2024-10-04 18:01:37 +08:00
|
|
|
}
|
2025-02-03 14:30:38 +01:00
|
|
|
|
2025-07-24 11:22:12 +01:00
|
|
|
_TRANSFORMERS_SUPPORTED_MODELS = {
|
2025-08-12 13:38:48 +01:00
|
|
|
# Text generation models
|
|
|
|
|
"SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"),
|
|
|
|
|
# Multimodal models
|
2025-10-16 22:50:39 +01:00
|
|
|
"Emu3ForConditionalGeneration": (
|
|
|
|
|
"transformers",
|
|
|
|
|
"TransformersMultiModalForCausalLM",
|
|
|
|
|
),
|
2025-07-24 11:22:12 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_TRANSFORMERS_BACKEND_MODELS = {
|
2025-10-16 22:50:39 +01:00
|
|
|
# Text generation models
|
2025-03-26 10:13:38 +00:00
|
|
|
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
2025-10-16 22:50:39 +01:00
|
|
|
"TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"),
|
|
|
|
|
# Multimodal models
|
|
|
|
|
"TransformersMultiModalForCausalLM": (
|
|
|
|
|
"transformers",
|
|
|
|
|
"TransformersMultiModalForCausalLM",
|
|
|
|
|
),
|
|
|
|
|
"TransformersMultiModalMoEForCausalLM": (
|
|
|
|
|
"transformers",
|
|
|
|
|
"TransformersMultiModalMoEForCausalLM",
|
|
|
|
|
),
|
|
|
|
|
# Embedding models
|
|
|
|
|
"TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"),
|
|
|
|
|
"TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"),
|
|
|
|
|
"TransformersMultiModalEmbeddingModel": (
|
|
|
|
|
"transformers",
|
|
|
|
|
"TransformersMultiModalEmbeddingModel",
|
|
|
|
|
),
|
|
|
|
|
# Sequence classification models
|
2025-10-05 17:18:11 +01:00
|
|
|
"TransformersForSequenceClassification": (
|
2025-10-16 22:50:39 +01:00
|
|
|
"transformers",
|
2025-10-05 17:18:11 +01:00
|
|
|
"TransformersForSequenceClassification",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-05 17:18:11 +01:00
|
|
|
"TransformersMoEForSequenceClassification": (
|
2025-10-16 22:50:39 +01:00
|
|
|
"transformers",
|
2025-10-05 17:18:11 +01:00
|
|
|
"TransformersMoEForSequenceClassification",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-10-16 22:50:39 +01:00
|
|
|
"TransformersMultiModalForSequenceClassification": (
|
|
|
|
|
"transformers",
|
|
|
|
|
"TransformersMultiModalForSequenceClassification",
|
2025-10-06 00:29:18 -07:00
|
|
|
),
|
2025-02-03 14:30:38 +01:00
|
|
|
}
|
2024-10-04 18:01:37 +08:00
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
# Union of every per-category registry above. Later entries win on key clashes.
_VLLM_MODELS = {
    **_TEXT_GENERATION_MODELS,
    **_EMBEDDING_MODELS,
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_TRANSFORMERS_SUPPORTED_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}
|
|
|
|
|
|
2025-02-12 02:36:10 -08:00
|
|
|
# This variable is used as the args for subprocess.run(). We
|
|
|
|
|
# can modify this variable to alter the args if needed. e.g.
|
|
|
|
|
# when we use par format to pack things together, sys.executable
|
|
|
|
|
# might not be the target we want to run.
|
|
|
|
|
_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"]
|
|
|
|
|
|
2025-09-17 01:46:46 +08:00
|
|
|
_PREVIOUSLY_SUPPORTED_MODELS = {
|
2025-09-30 00:36:30 +08:00
|
|
|
"MotifForCausalLM": "0.10.2",
|
2025-09-17 01:46:46 +08:00
|
|
|
"Phi3SmallForCausalLM": "0.9.2",
|
2025-09-29 07:08:17 +02:00
|
|
|
"Phi4FlashForCausalLM": "0.10.2",
|
2025-12-04 13:44:50 +00:00
|
|
|
"Phi4MultimodalForCausalLM": "0.12.0",
|
2025-09-17 01:46:46 +08:00
|
|
|
# encoder-decoder models except whisper
|
|
|
|
|
# have been removed for V0 deprecation.
|
|
|
|
|
"BartModel": "0.10.2",
|
|
|
|
|
"BartForConditionalGeneration": "0.10.2",
|
|
|
|
|
"DonutForConditionalGeneration": "0.10.2",
|
|
|
|
|
"Florence2ForConditionalGeneration": "0.10.2",
|
|
|
|
|
"MBartForConditionalGeneration": "0.10.2",
|
|
|
|
|
"MllamaForConditionalGeneration": "0.10.2",
|
|
|
|
|
}
|
2025-07-23 15:03:16 +08:00
|
|
|
|
2024-10-04 18:01:37 +08:00
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
@dataclass(frozen=True)
class _ModelInfo:
    """Static capability flags for one model class.

    Computed from the class via this module's inspection helpers in
    `from_model_cls`. Instances are frozen and contain only plain values,
    which lets `_LazyRegisteredModel` serialize them with
    `dataclasses.asdict` and cache them on disk as JSON.
    """

    # Name of the model class (``model.__name__``).
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
    requires_raw_input_tokens: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
    is_hybrid: bool
    has_noops: bool
    supports_mamba_prefix_caching: bool
    supports_transcription: bool
    supports_transcription_only: bool

    @staticmethod
    def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
        """Inspect ``model`` once and collect all of its capability flags."""
        return _ModelInfo(
            architecture=model.__name__,
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
                model
            ),
            requires_raw_input_tokens=requires_raw_input_tokens(model),
            supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
                model
            ),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_mamba_prefix_caching=supports_mamba_prefix_caching(model),
            supports_transcription=supports_transcription(model),
            # "transcription only" additionally requires the class-level
            # `supports_transcription_only` attribute to be truthy.
            supports_transcription_only=(
                supports_transcription(model) and model.supports_transcription_only
            ),
            has_noops=has_noops(model),
        )
|
2024-10-04 10:38:25 -07:00
|
|
|
|
|
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
class _BaseRegisteredModel(ABC):
    """Interface for a registry entry.

    An entry can report the capability flags of its model class
    (`inspect_model_cls`) and produce the actual `nn.Module` subclass
    (`load_model_cls`), either eagerly or lazily.
    """

    @abstractmethod
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the capability flags for the underlying model class."""
        raise NotImplementedError

    @abstractmethod
    def load_model_cls(self) -> type[nn.Module]:
        """Return the underlying model class, importing it if necessary."""
        raise NotImplementedError
|
2024-10-04 18:01:37 +08:00
|
|
|
|
|
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has already been imported in the main process.

    Both the class object and its pre-computed `_ModelInfo` are held
    directly, so inspection and loading are simple attribute reads.
    """

    interfaces: _ModelInfo
    model_cls: type[nn.Module]

    @staticmethod
    def from_model_cls(model_cls: type[nn.Module]):
        """Build an entry by inspecting ``model_cls`` right away."""
        info = _ModelInfo.from_model_cls(model_cls)
        return _RegisteredModel(interfaces=info, model_cls=model_cls)

    def inspect_model_cls(self) -> _ModelInfo:
        return self.interfaces

    def load_model_cls(self) -> type[nn.Module]:
        return self.model_cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.

    Inspection results are cached on disk, keyed by a hash of the model's
    source file, so repeated startups avoid re-running the (subprocess-based)
    inspection.
    """

    module_name: str
    class_name: str

    @staticmethod
    def _get_cache_dir() -> Path:
        # Per-user cache directory for serialized `_ModelInfo` JSON files.
        return Path(envs.VLLM_CACHE_ROOT) / "modelinfos"

    def _get_cache_filename(self) -> str:
        # One file per (module, class); dots are not path-safe, so flatten.
        cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-")
        return f"{cls_name}.json"

    def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None:
        """Return the cached `_ModelInfo`, or None on miss/stale/corruption.

        Any failure is treated as a cache miss (logged at debug level);
        the caller falls back to a full inspection.
        """
        try:
            try:
                modelinfo_path = self._get_cache_dir() / self._get_cache_filename()
                with open(modelinfo_path, encoding="utf-8") as file:
                    mi_dict = json.load(file)
            except FileNotFoundError:
                logger.debug(
                    "Cached model info file for class %s.%s not found",
                    self.module_name,
                    self.class_name,
                )
                return None

            # The stored hash must match the current source-file hash,
            # otherwise the cached flags may describe an older class.
            if mi_dict["hash"] != module_hash:
                logger.debug(
                    "Cached model info file for class %s.%s is stale",
                    self.module_name,
                    self.class_name,
                )
                return None

            # file not changed, use cached _ModelInfo properties
            return _ModelInfo(**mi_dict["modelinfo"])
        except Exception:
            # Corrupt JSON, missing keys, unexpected schema, etc.
            logger.debug(
                "Cached model info for class %s.%s error. ",
                self.module_name,
                self.class_name,
            )
            return None

    def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:
        """save dictionary json file to cache"""
        # Local import to avoid pulling in the model loader at module load.
        from vllm.model_executor.model_loader.weight_utils import atomic_writer

        try:
            modelinfo_dict = {
                "hash": module_hash,
                "modelinfo": asdict(mi),
            }
            cache_dir = self._get_cache_dir()
            cache_dir.mkdir(parents=True, exist_ok=True)
            modelinfo_path = cache_dir / self._get_cache_filename()
            # Atomic write so a concurrent reader never sees a partial file.
            with atomic_writer(modelinfo_path, encoding="utf-8") as f:
                json.dump(modelinfo_dict, f, indent=2)
        except Exception:
            # Caching is best-effort; failure to save must not break startup.
            logger.exception("Error saving model info cache.")

    @logtime(logger=logger, msg="Registry inspect model class")
    def inspect_model_cls(self) -> _ModelInfo:
        """Return the `_ModelInfo` for this entry, using the disk cache.

        Cache key: hash of the model's source file (assumed to live next to
        this file as `<last module component>.py` — NOTE(review): out-of-tree
        modules won't match and simply skip the cache).
        """
        model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
        module_hash = None

        if model_path.exists():
            with open(model_path, "rb") as f:
                module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

            mi = self._load_modelinfo_from_cache(module_hash)
            if mi is not None:
                logger.debug(
                    "Loaded model info for class %s.%s from cache",
                    self.module_name,
                    self.class_name,
                )
                return mi
            else:
                logger.debug(
                    "Cache model info for class %s.%s miss. Loading model instead.",
                    self.module_name,
                    self.class_name,
                )

        # Performed in another process to avoid initializing CUDA
        mi = _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls())
        )
        logger.debug(
            "Loaded model info for class %s.%s", self.module_name, self.class_name
        )

        # save cache file
        if module_hash is not None:
            self._save_modelinfo_to_cache(mi, module_hash)

        return mi

    def load_model_cls(self) -> type[nn.Module]:
        # Import happens here, in whichever process calls this.
        mod = importlib.import_module(self.module_name)
        return getattr(mod, self.class_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cached per (arch, entry); registry entries are frozen dataclasses, hence
# hashable and usable as lru_cache keys.
@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
    """Load the model class for ``model_arch``, or None if loading fails.

    Verifies the architecture against the current platform first; import
    errors are logged and swallowed so callers can try other architectures.
    """
    from vllm.platforms import current_platform

    current_platform.verify_model_arch(model_arch)
    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'", model_arch)
        return None
|
2024-10-04 18:01:37 +08:00
|
|
|
|
|
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
# Cached per (arch, entry), mirroring `_try_load_model_cls` above.
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> _ModelInfo | None:
    """Inspect ``model_arch``'s capability flags, or None if inspection fails.

    Failures are logged and swallowed so callers can try other architectures.
    """
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'", model_arch)
        return None
|
2024-10-04 18:01:37 +08:00
|
|
|
|
|
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
@dataclass
class _ModelRegistry:
    """Registry mapping architecture names to loadable model entries.

    Resolution (`inspect_model_cls` / `resolve_model_cls`) tries registered
    entries first, falling back to the Transformers backend according to
    `model_config.model_impl`.
    """

    # Keyed by model_arch
    models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> Set[str]:
        """Return a live view of all registered architecture names."""
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: type[nn.Module] | str,
    ) -> None:
        """
        Register an external model to be used in vLLM.

        `model_cls` can be either:

        - A [`torch.nn.Module`][] class directly referencing the model.
        - A string in the format `<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          `RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Raises:
            TypeError: If `model_arch` is not a string, or `model_cls` is
                neither a string nor an `nn.Module` subclass.
            ValueError: If a string `model_cls` is not `<module>:<class>`.
        """
        if not isinstance(model_arch, str):
            msg = f"`model_arch` should be a string, not a {type(model_arch)}"
            raise TypeError(msg)

        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.",
                model_arch,
                model_cls,
            )

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
            model = _RegisteredModel.from_model_cls(model_cls)
        else:
            # Bug fix: the message previously reported `type(model_arch)`
            # (always `str` at this point) instead of the offending value.
            msg = (
                "`model_cls` should be a string or PyTorch model class, "
                f"not a {type(model_cls)}"
            )
            raise TypeError(msg)

        self.models[model_arch] = model

    def _raise_for_unsupported(self, architectures: list[str]):
        """Always raises, with the most specific message available."""
        all_supported_archs = self.get_supported_archs()

        # Registered but failed to inspect/load: point the user to the logs.
        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details."
            )

        # Architectures removed from vLLM get a version-specific hint.
        for arch in architectures:
            if arch in _PREVIOUSLY_SUPPORTED_MODELS:
                previous_version = _PREVIOUSLY_SUPPORTED_MODELS[arch]

                raise ValueError(
                    f"Model architecture {arch} was supported in vLLM until "
                    f"v{previous_version}, and is not supported anymore. "
                    "Please use an older version of vLLM if you want to "
                    "use this model architecture."
                )

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}"
        )

    def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
        """Load a registered architecture's class; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
        """Inspect a registered architecture; None if unknown/failed."""
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _try_resolve_transformers(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str | None:
        """Map ``architecture`` to a Transformers-backend architecture name.

        Returns None when the model cannot be resolved through Transformers
        and `model_impl` is not explicitly "transformers"; raises a
        descriptive ValueError when it is.
        """
        if architecture in _TRANSFORMERS_BACKEND_MODELS:
            return architecture

        auto_map: dict[str, str] = (
            getattr(model_config.hf_config, "auto_map", None) or dict()
        )

        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        for prefix in ("AutoConfig", "AutoModel"):
            for name, module in auto_map.items():
                if name.startswith(prefix):
                    try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=False,
                    )

        model_module = getattr(transformers, architecture, None)

        if model_module is None:
            # Not a stock Transformers class; try the repo's custom code.
            for name, module in auto_map.items():
                if name.startswith("AutoModel"):
                    model_module = try_get_class_from_dynamic_module(
                        module,
                        model_config.model,
                        revision=model_config.revision,
                        warn_on_fail=True,
                    )
                    if model_module is not None:
                        break
            else:
                if model_config.model_impl != "transformers":
                    return None

                raise ValueError(
                    f"Cannot find model module. {architecture!r} is not a "
                    "registered model in the Transformers library (only "
                    "relevant if the model is meant to be in Transformers) "
                    "and 'AutoModel' is not present in the model config's "
                    "'auto_map' (relevant if the model is custom)."
                )

        if not model_module.is_backend_compatible():
            if model_config.model_impl != "transformers":
                return None

            raise ValueError(
                f"The Transformers implementation of {architecture!r} "
                "is not compatible with vLLM."
            )

        return model_config._get_transformers_backend_cls()

    def _normalize_arch(
        self,
        architecture: str,
        model_config: ModelConfig,
    ) -> str:
        """Map a convert-suffixed architecture to its registered base arch."""
        if architecture in self.models:
            return architecture

        # This may be called in order to resolve runner_type and convert_type
        # in the first place, in which case we consider the default match
        match = try_match_architecture_defaults(
            architecture,
            runner_type=getattr(model_config, "runner_type", None),
            convert_type=getattr(model_config, "convert_type", None),
        )
        if match:
            suffix, _ = match

            # Get the name of the base model to convert
            for repl_suffix, _ in iter_architecture_defaults():
                base_arch = architecture.replace(suffix, repl_suffix)
                if base_arch in self.models:
                    return base_arch

        return architecture

    def inspect_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[_ModelInfo, str]:
        """Return (model info, architecture) for the first resolvable arch.

        Raises ValueError (via `_raise_for_unsupported`) when nothing
        resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)
        elif model_config.model_impl == "terratorch":
            # NOTE(review): `model_info` may be None here if inspection of
            # "Terratorch" fails; callers would then receive (None, ...).
            model_info = self._try_inspect_model_cls("Terratorch")
            return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        # Registered architectures (after normalizing convert suffixes).
        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_info = self._try_inspect_model_cls(normalized_arch)
            if model_info is not None:
                return (model_info, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_info = self._try_inspect_model_cls(arch)
                if model_info is not None:
                    return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> tuple[type[nn.Module], str]:
        """Return (model class, architecture) for the first resolvable arch.

        Mirrors `inspect_model_cls` but loads the class instead of its info.
        Raises ValueError when nothing resolves.
        """
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            raise ValueError("No model architectures are specified")

        # Require transformers impl
        if model_config.model_impl == "transformers":
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)
        elif model_config.model_impl == "terratorch":
            arch = "Terratorch"
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
            and getattr(model_config, "convert_type", "none") == "none"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        # Registered architectures (after normalizing convert suffixes).
        for arch in architectures:
            normalized_arch = self._normalize_arch(arch, model_config)
            model_cls = self._try_load_model_cls(normalized_arch)
            if model_cls is not None:
                return (model_cls, arch)

        # Fallback to transformers impl (before resolving runner_type)
        if (
            all(arch not in self.models for arch in architectures)
            and model_config.model_impl == "auto"
        ):
            arch = self._try_resolve_transformers(architectures[0], model_config)
            if arch is not None:
                model_cls = self._try_load_model_cls(arch)
                if model_cls is not None:
                    return (model_cls, arch)

        return self._raise_for_unsupported(architectures)

    def is_text_generation_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports text generation."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_text_generation_model

    def is_pooling_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a pooling model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_pooling_model

    def is_cross_encoder_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports cross-encoding."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_cross_encoding

    def is_multimodal_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports multimodal inputs."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal

    def is_multimodal_raw_input_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model accepts only raw multimodal input."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_multimodal_raw_input_only

    def is_pp_supported_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports pipeline parallelism."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_pp

    def model_has_inner_state(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model carries inner state."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_inner_state

    def is_attention_free_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is attention-free."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_attention_free

    def is_hybrid_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model is a hybrid model."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.is_hybrid

    def is_noops_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model has no-op layers."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.has_noops

    def is_transcription_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: str | list[str],
        model_config: ModelConfig,
    ) -> bool:
        """Whether the resolved model supports only transcription."""
        model_cls, _ = self.inspect_model_cls(architectures, model_config)
        return model_cls.supports_transcription_only
|
|
|
|
|
|
2024-10-11 19:08:11 +08:00
|
|
|
|
|
|
|
|
# The global registry instance, pre-populated with every in-tree architecture
# as a lazy entry so that importing this module does not import model code.
ModelRegistry = _ModelRegistry(
    {
        model_arch: _LazyRegisteredModel(
            module_name=f"vllm.model_executor.models.{mod_relname}",
            class_name=cls_name,
        )
        for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
    }
)
|
|
|
|
|
|
|
|
|
|
# Return type of the callable passed to `_run_in_subprocess`.
_T = TypeVar("_T")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
    """Execute ``fn`` in a fresh Python subprocess and return its result.

    The callable is shipped to the child (this module run with ``-m``) via
    ``cloudpickle`` on stdin; the result travels back through a pickle file
    in a scratch directory.
    """
    # NOTE: A temporary directory (rather than a temporary file) sidesteps
    # re-open permission issues, see
    # https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
    with tempfile.TemporaryDirectory() as scratch_dir:
        result_path = os.path.join(scratch_dir, "registry_output.tmp")

        # `cloudpickle` (unlike plain `pickle`) can serialize lambdas.
        import cloudpickle

        payload = cloudpickle.dumps((fn, result_path))

        # `sys.executable __file__` would break this file's relative
        # imports, so the child is launched as a module instead.
        proc = subprocess.run(
            _SUBPROCESS_COMMAND, input=payload, capture_output=True
        )

        try:
            proc.check_returncode()
        except Exception as e:
            # Surface the child's stderr so failures are diagnosable.
            raise RuntimeError(
                f"Error raised in subprocess:\n{proc.stderr.decode()}"
            ) from e

        with open(result_path, "rb") as f:
            return pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run() -> None:
    """Subprocess entry point: run a pickled callable received on stdin.

    Reads a ``(fn, output_file)`` pair (pickled by `_run_in_subprocess`)
    from stdin, executes ``fn``, and writes the pickled result to
    ``output_file``.

    NOTE(review): `pickle.loads` on stdin is only acceptable because this
    process is spawned by the parent via `_SUBPROCESS_COMMAND`; it must
    never be fed untrusted input.
    """
    # Setup plugins
    from vllm.plugins import load_general_plugins

    load_general_plugins()

    fn, output_file = pickle.loads(sys.stdin.buffer.read())

    result = fn()

    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
|
2024-10-11 19:08:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Entry point for the worker subprocess launched via `_SUBPROCESS_COMMAND`.
if __name__ == "__main__":
    _run()
|