[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -10,9 +10,7 @@ from pathlib import Path
from typing import Any, Literal, TypeAlias
import huggingface_hub
from huggingface_hub import (
get_safetensors_metadata,
)
from huggingface_hub import get_safetensors_metadata
from packaging.version import Version
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import get_image_processor_config
@@ -742,7 +740,10 @@ def get_config(
@cache
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
def get_pooling_config(
model: str,
revision: str | None = "main",
) -> dict[str, Any] | None:
"""
This function gets the pooling and normalize
config from the model - only applies to
@@ -793,38 +794,40 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
)
if pooling:
pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
pooling_type_name = next(
(item for item, val in pooling_dict.items() if val is True), None
)
from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES
if pooling_type_name is not None:
pooling_type_name = get_pooling_config_name(pooling_type_name)
pooling_file_name = "{}/config.json".format(pooling["path"])
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) or {}
logger.info("Found pooling configuration.")
return {"pooling_type": pooling_type_name, "normalize": normalize}
config: dict[str, Any] = {"normalize": normalize}
for key, val in pooling_dict.items():
if val is True:
pooling_type = parse_pooling_type(key)
if pooling_type in SEQ_POOLING_TYPES:
config["seq_pooling_type"] = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
config["tok_pooling_type"] = pooling_type
else:
logger.debug("Skipping unrelated field: %r=%r", key, val)
return config
return None
def get_pooling_config_name(pooling_name: str) -> str | None:
def parse_pooling_type(pooling_name: str):
if "pooling_mode_" in pooling_name:
pooling_name = pooling_name.replace("pooling_mode_", "")
if "_" in pooling_name:
pooling_name = pooling_name.split("_")[0]
pooling_name = pooling_name.split("_", 1)[0]
if "lasttoken" in pooling_name:
pooling_name = "last"
supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
pooling_type_name = pooling_name.upper()
if pooling_type_name in supported_pooling_types:
return pooling_type_name
raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
return pooling_name.upper()
@cache