[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,9 +10,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import (
|
||||
get_safetensors_metadata,
|
||||
)
|
||||
from huggingface_hub import get_safetensors_metadata
|
||||
from packaging.version import Version
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||
@@ -742,7 +740,10 @@ def get_config(
|
||||
|
||||
|
||||
@cache
|
||||
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
|
||||
def get_pooling_config(
|
||||
model: str,
|
||||
revision: str | None = "main",
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
This function gets the pooling and normalize
|
||||
config from the model - only applies to
|
||||
@@ -793,38 +794,40 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
|
||||
)
|
||||
|
||||
if pooling:
|
||||
pooling_file_name = "{}/config.json".format(pooling["path"])
|
||||
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
|
||||
pooling_type_name = next(
|
||||
(item for item, val in pooling_dict.items() if val is True), None
|
||||
)
|
||||
from vllm.config.pooler import SEQ_POOLING_TYPES, TOK_POOLING_TYPES
|
||||
|
||||
if pooling_type_name is not None:
|
||||
pooling_type_name = get_pooling_config_name(pooling_type_name)
|
||||
pooling_file_name = "{}/config.json".format(pooling["path"])
|
||||
pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) or {}
|
||||
|
||||
logger.info("Found pooling configuration.")
|
||||
return {"pooling_type": pooling_type_name, "normalize": normalize}
|
||||
|
||||
config: dict[str, Any] = {"normalize": normalize}
|
||||
for key, val in pooling_dict.items():
|
||||
if val is True:
|
||||
pooling_type = parse_pooling_type(key)
|
||||
if pooling_type in SEQ_POOLING_TYPES:
|
||||
config["seq_pooling_type"] = pooling_type
|
||||
elif pooling_type in TOK_POOLING_TYPES:
|
||||
config["tok_pooling_type"] = pooling_type
|
||||
else:
|
||||
logger.debug("Skipping unrelated field: %r=%r", key, val)
|
||||
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_pooling_config_name(pooling_name: str) -> str | None:
|
||||
def parse_pooling_type(pooling_name: str):
|
||||
if "pooling_mode_" in pooling_name:
|
||||
pooling_name = pooling_name.replace("pooling_mode_", "")
|
||||
|
||||
if "_" in pooling_name:
|
||||
pooling_name = pooling_name.split("_")[0]
|
||||
pooling_name = pooling_name.split("_", 1)[0]
|
||||
|
||||
if "lasttoken" in pooling_name:
|
||||
pooling_name = "last"
|
||||
|
||||
supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
pooling_type_name = pooling_name.upper()
|
||||
|
||||
if pooling_type_name in supported_pooling_types:
|
||||
return pooling_type_name
|
||||
|
||||
raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
|
||||
return pooling_name.upper()
|
||||
|
||||
|
||||
@cache
|
||||
|
||||
Reference in New Issue
Block a user