[Model] Introduce verify_and_update_model_config for VerifyAndUpdateConfig. (#31131)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -21,20 +21,24 @@ logger = init_logger(__name__)
|
||||
class VerifyAndUpdateConfig:
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
raise NotImplementedError
|
||||
return
|
||||
|
||||
|
||||
class Gemma3TextModelConfig:
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
return
|
||||
|
||||
|
||||
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
hf_config = model_config.hf_config
|
||||
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
||||
|
||||
|
||||
class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NewConfig"
|
||||
assert config.hidden_act == "gelu"
|
||||
@@ -53,16 +57,15 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config.use_activation is None:
|
||||
pooler_config.use_activation = False
|
||||
|
||||
|
||||
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
model_config = vllm_config.model_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
if config.position_embedding_type == "rotary":
|
||||
@@ -90,10 +93,10 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
hf_config = model_config.hf_config
|
||||
hf_config.is_causal = False
|
||||
|
||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||
@@ -105,7 +108,7 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||
if pooling_type is None:
|
||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||
vllm_config.model_config.pooler_config.pooling_type = pooling_type
|
||||
model_config.pooler_config.pooling_type = pooling_type
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
@@ -204,8 +207,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
|
||||
if pooler_config.step_tag_id is None:
|
||||
pooler_config.step_tag_id = 151651
|
||||
@@ -213,8 +216,8 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
|
||||
if pooler_config.softmax is None:
|
||||
pooler_config.softmax = False
|
||||
@@ -222,8 +225,8 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
is_original_qwen3_reranker = getattr(
|
||||
config, "is_original_qwen3_reranker", False
|
||||
@@ -237,23 +240,23 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
"Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
|
||||
)
|
||||
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
|
||||
model_config.hf_config.method = "from_2_way_softmax"
|
||||
|
||||
|
||||
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
config.num_labels = 1
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config.logit_bias is None:
|
||||
pooler_config.logit_bias = 2.65
|
||||
|
||||
|
||||
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "GteConfig"
|
||||
assert config.hidden_act == "gelu"
|
||||
|
||||
Reference in New Issue
Block a user