[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType
hf_config = model_config.hf_config
hf_config.is_causal = False
pooling_type_map: dict[str, PoolingTypeStr] = {
pooling_type_map: dict[str, SequencePoolingType] = {
"avg": "MEAN",
"cls": "CLS",
"last": "LAST",
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
pooling_type = pooling_type_map.get(hf_config.pooling, None)
if pooling_type is None:
raise ValueError(f"pool_type {hf_config.pooling} not supported")
model_config.pooler_config.pooling_type = pooling_type
raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
model_config.pooler_config.seq_pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig):