Order config.py in Lexicographical order (#35866)

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
This commit is contained in:
Andrii Skliar
2026-03-05 05:56:47 +01:00
committed by GitHub
parent dd6dbd93f8
commit 0a12cea25f

View File

@@ -28,305 +28,26 @@ class VerifyAndUpdateConfig:
return
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Gemma3 text model: derive attention causality from the HF config."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        hf_config = model_config.hf_config
        # Bidirectional attention (embedding-style usage) implies non-causal masks.
        hf_config.is_causal = not hf_config.use_bidirectional_attention
class GteNewModelConfig(VerifyAndUpdateConfig):
    """GTE "NewConfig" models: switch to gated GELU and populate RoPE kwargs."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        assert config.__class__.__name__ == "NewConfig"
        assert config.hidden_act == "gelu"

        # The MLP is gated, so the activation name is rewritten to "geglu".
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Jamba classifier: default to raw logits (no pooler activation)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Jina XLM-RoBERTa models: set up rotary embeddings when enabled."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        if config.position_embedding_type == "rotary":
            assert config.__class__.__name__ == "XLMRobertaFlashConfig"

            head_dim = config.hidden_size // config.num_attention_heads
            max_position = config.max_position_embeddings

            # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
            # out-of-bound index issue at RoPE for long prompts with torch.compile,
            # because it can't be divided by triton num_warps(default=4 or 8).
            # To deal with this, we increase max_position to multiple of n_warps,
            # so that triton kernel won't hit out-of-bound index in RoPE cache.
            if not model_config.enforce_eager:
                max_position = round_up(max_position, 8)

            rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
            config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
            config.rotary_kwargs = {
                "head_size": head_dim,
                "max_position": max_position,
                "rope_parameters": config.rope_parameters,
            }
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    """Bidirectional Llama: force non-causal attention and map pooling type."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # Local import to avoid an import cycle at module load time.
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config
        hf_config.is_causal = False

        # Translate the HF-style pooling name to vLLM's pooling enum value.
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }
        pooling_type = pooling_type_map.get(hf_config.pooling, None)
        if pooling_type is None:
            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
        model_config.pooler_config.seq_pooling_type = pooling_type
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # Local import to avoid an import cycle at module load time.
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config

        # Set bidirectional attention on the language model config
        hf_config.is_causal = False
        if hasattr(hf_config, "llm_config"):
            hf_config.llm_config.is_causal = False
        if hasattr(hf_config, "vision_config"):
            hf_config.patch_size = hf_config.vision_config.patch_size

        # Set up pooling type
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        # Get pooling type from config (check both top-level and llm_config)
        pooling = getattr(hf_config, "pooling", None)
        if pooling is None and hasattr(hf_config, "llm_config"):
            pooling = getattr(hf_config.llm_config, "pooling", "avg")

        pooling_type = pooling_type_map.get(pooling)
        if pooling_type is None:
            raise ValueError(f"pool_type {pooling!r} not supported")
        model_config.pooler_config.seq_pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig):
    """NomicBert: translate NomicBertConfig fields to the names vLLM expects
    and clamp max_model_len to the trained position range by default."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # NOTE(review): nesting reconstructed from a whitespace-mangled
        # source — verify against upstream before merging.
        config = model_config.hf_config
        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        # "swiglu" maps to vLLM's "silu"-gated MLP; "gelu" passes through.
        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        # Rename GPT-2-style fields to the HF-standard names.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config
            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Qwen2 PRM: supply the model's default step-tag token id."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.step_tag_id is None:
            pooler_config.step_tag_id = 151651
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Qwen2 reward model: default to raw logits (no pooler activation)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Qwen3 reranker: convert an original (causal-LM-style) reranker config
    into the 2-way-softmax classification form vLLM expects."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        is_original_qwen3_reranker = getattr(
            config, "is_original_qwen3_reranker", False
        )
        if not is_original_qwen3_reranker:
            return

        # The original reranker scores via two tokens (e.g. "yes"/"no").
        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
        )
        text_config = config.get_text_config()
        text_config.method = "from_2_way_softmax"
        text_config.classifier_from_token = tokens
class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL reranker: identical conversion logic to the text-only Qwen3."""
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Jina-VL ranker: single-label scoring with a model-specific logit bias."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config
        config.num_labels = 1

        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.logit_bias is None:
            pooler_config.logit_bias = 2.65
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Snowflake GTE ("GteConfig"): switch to gated GELU and set up RoPE."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.hidden_act == "gelu"

        # The MLP is gated, so the activation name is rewritten to "geglu".
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
        """
        hf_config = vllm_config.model_config.hf_config

        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
        is_v32 = hasattr(hf_config, "index_topk")
        assert is_v32

        # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        if cache_config.cache_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
@@ -337,6 +58,13 @@ class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
vllm_config.compilation_config.fast_moe_cold_start = False
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
    """Gemma3 text model: derive attention causality from the HF config."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        hf_config = model_config.hf_config
        # Bidirectional attention (embedding-style usage) implies non-causal masks.
        hf_config.is_causal = not hf_config.use_bidirectional_attention
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -360,64 +88,24 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
)
# NOTE(review): this span interleaved the removed MambaModelConfig with the
# added GteNewModelConfig (merged-diff corruption); both are reconstructed
# here — verify ordering against the committed file.
class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
        to get good performance for mamba layers in V1).

        Args:
            vllm_config: vLLM Config
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len


class GteNewModelConfig(VerifyAndUpdateConfig):
    """GTE "NewConfig" models: switch to gated GELU and populate RoPE kwargs."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        assert config.__class__.__name__ == "NewConfig"
        assert config.hidden_act == "gelu"

        # The MLP is gated, so the activation name is rewritten to "geglu".
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@@ -580,26 +268,167 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
)
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Jamba classifier: default to raw logits (no pooler activation)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
    """Jina XLM-RoBERTa models: set up rotary embeddings when enabled."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        if config.position_embedding_type == "rotary":
            assert config.__class__.__name__ == "XLMRobertaFlashConfig"

            head_dim = config.hidden_size // config.num_attention_heads
            max_position = config.max_position_embeddings

            # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
            # out-of-bound index issue at RoPE for long prompts with torch.compile,
            # because it can't be divided by triton num_warps(default=4 or 8).
            # To deal with this, we increase max_position to multiple of n_warps,
            # so that triton kernel won't hit out-of-bound index in RoPE cache.
            if not model_config.enforce_eager:
                max_position = round_up(max_position, 8)

            rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
            config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
            config.rotary_kwargs = {
                "head_size": head_dim,
                "max_position": max_position,
                "rope_parameters": config.rope_parameters,
            }
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Jina-VL ranker: single-label scoring with a model-specific logit bias."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config
        config.num_labels = 1

        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.logit_bias is None:
            pooler_config.logit_bias = 2.65
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
    """Bidirectional Llama: force non-causal attention and map pooling type."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # Local import to avoid an import cycle at module load time.
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config
        hf_config.is_causal = False

        # Translate the HF-style pooling name to vLLM's pooling enum value.
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }
        pooling_type = pooling_type_map.get(hf_config.pooling, None)
        if pooling_type is None:
            raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
        model_config.pooler_config.seq_pooling_type = pooling_type
class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
    """Config handler for LlamaNemotronVL embedding models."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # Local import to avoid an import cycle at module load time.
        from vllm.config.pooler import SequencePoolingType

        hf_config = model_config.hf_config

        # Set bidirectional attention on the language model config
        hf_config.is_causal = False
        if hasattr(hf_config, "llm_config"):
            hf_config.llm_config.is_causal = False
        if hasattr(hf_config, "vision_config"):
            hf_config.patch_size = hf_config.vision_config.patch_size

        # Set up pooling type
        pooling_type_map: dict[str, SequencePoolingType] = {
            "avg": "MEAN",
            "cls": "CLS",
            "last": "LAST",
        }

        # Get pooling type from config (check both top-level and llm_config)
        pooling = getattr(hf_config, "pooling", None)
        if pooling is None and hasattr(hf_config, "llm_config"):
            pooling = getattr(hf_config.llm_config, "pooling", "avg")

        pooling_type = pooling_type_map.get(pooling)
        if pooling_type is None:
            raise ValueError(f"pool_type {pooling!r} not supported")
        model_config.pooler_config.seq_pooling_type = pooling_type
# NOTE(review): this span interleaved MambaModelConfig with removed-line
# residue of DeepseekV32ForCausalLM (merged-diff corruption); only Mamba is
# reconstructed here — DeepseekV32ForCausalLM exists intact earlier in the file.
class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
        to get good performance for mamba layers in V1).

        Args:
            vllm_config: vLLM Config
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
@@ -631,6 +460,157 @@ class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
video_kwargs.setdefault("video_backend", "nemotron_vl")
class NomicBertModelConfig(VerifyAndUpdateConfig):
    """NomicBert: translate NomicBertConfig fields to the names vLLM expects
    and clamp max_model_len to the trained position range by default."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # NOTE(review): nesting reconstructed from a whitespace-mangled
        # source — verify against upstream before merging.
        config = model_config.hf_config
        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(
            config, "position_embedding_type", "rope"
        )

        # "swiglu" maps to vLLM's "silu"-gated MLP; "gelu" passes through.
        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        # Rename GPT-2-style fields to the HF-standard names.
        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer
        model_config.model_arch_config.hidden_size = config.hidden_size
        model_config.model_arch_config.total_num_hidden_layers = (
            config.num_hidden_layers
        )

        head_dim = config.hidden_size // config.num_attention_heads
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": max_trained_positions,
            "rope_parameters": config.rope_parameters,
        }

        # we ignore config.rotary_scaling_factor so that for datasets shorter
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_parameters.
        # See #17785 #18755
        if (
            not model_config.hf_overrides
            and model_config.original_max_model_len is None
        ):
            # Default
            # Reset max_model_len to max_trained_positions.
            # nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = model_config.max_model_len
            max_model_len = min(model_config.max_model_len, max_trained_positions)

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
            if model_config.max_model_len != max_model_len_before:
                logger.warning(
                    "Nomic context extension is disabled. "
                    "Changing max_model_len from %s to %s. "
                    "To enable context extension, see: "
                    "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
                    max_model_len_before,
                    model_config.max_model_len,
                )
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            hf_text_config = model_config.hf_text_config
            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", model_config.max_model_len
                )
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_parameters = config.rotary_kwargs["rope_parameters"]

            # Update the cached derived_max_model_len to enforce the limit
            model_config.model_arch_config.derived_max_model_len_and_key = (
                float(max_trained_positions),
                "max_position_embeddings",
            )

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            model_config.max_model_len = model_config.get_and_verify_max_len(
                max_model_len
            )
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
    """Qwen2 PRM: supply the model's default step-tag token id."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.step_tag_id is None:
            pooler_config.step_tag_id = 151651
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
    """Qwen2 reward model: default to raw logits (no pooler activation)."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        pooler_config = model_config.pooler_config
        # Only override when the user has not set a preference.
        if pooler_config.use_activation is None:
            pooler_config.use_activation = False
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
    """Qwen3 reranker: convert an original (causal-LM-style) reranker config
    into the 2-way-softmax classification form vLLM expects."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        is_original_qwen3_reranker = getattr(
            config, "is_original_qwen3_reranker", False
        )
        if not is_original_qwen3_reranker:
            return

        # The original reranker scores via two tokens (e.g. "yes"/"no").
        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, (
            "Try loading the original Qwen3 Reranker?, see: "
            "https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
        )
        text_config = config.get_text_config()
        text_config.method = "from_2_way_softmax"
        text_config.classifier_from_token = tokens


class Qwen3VLForSequenceClassificationConfig(Qwen3ForSequenceClassificationConfig):
    """Qwen3-VL reranker: identical conversion logic to the text-only Qwen3."""
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -658,6 +638,26 @@ class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
)
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
    """Snowflake GTE ("GteConfig"): switch to gated GELU and set up RoPE."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        config = model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.hidden_act == "gelu"

        # The MLP is gated, so the activation name is rewritten to "geglu".
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = getattr(config, "rotary_emb_dim", head_dim)
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
        config.rotary_kwargs = {
            "head_size": head_dim,
            "max_position": config.max_position_embeddings,
            "rope_parameters": config.rope_parameters,
        }
class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
@@ -666,33 +666,33 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
# Architecture name -> config handler. Kept in lexicographical key order;
# the merged-diff source contained duplicate keys, deduplicated here.
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
    "FalconMambaForCausalLM": MambaModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "MambaForCausalLM": MambaModelConfig,
    "NemotronHForCausalLM": NemotronHForCausalLMConfig,
    "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
    "NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
    "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}