[Misc] Split up pooling tasks (#10820)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-11 17:28:00 +08:00
committed by GitHub
parent 40766ca1b8
commit 8f10d5e393
27 changed files with 527 additions and 168 deletions

View File

@@ -45,13 +45,27 @@ else:
logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding"]
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft"]
RunnerType = Literal["generate", "pooling", "draft"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"],
}
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items() for task in tasks
}
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
@@ -144,7 +158,7 @@ class ModelConfig:
def __init__(
self,
model: str,
task: Union[TaskOption, _Task],
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@@ -295,6 +309,7 @@ class ModelConfig:
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self.pooler_config = self._init_pooler_config(override_pooler_config)
self._verify_quantization()
@@ -323,7 +338,7 @@ class ModelConfig:
override_pooler_config: Optional["PoolerConfig"],
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
if self.runner_type == "pooling":
user_config = override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
@@ -357,60 +372,90 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(
self,
architectures: List[str],
supported_tasks: Set[_ResolvedTask],
) -> Optional[_ResolvedTask]:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if ModelRegistry.is_cross_encoder_model(architectures):
return "score"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ForSequenceClassification", "classify"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
return pref_task
return None
def _resolve_task(
self,
task_option: Union[TaskOption, _Task],
task_option: Union[TaskOption, Literal["draft"]],
hf_config: PretrainedConfig,
) -> Tuple[Set[_Task], _Task]:
) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[_Task, bool] = {
runner_support: Dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_pooling_model(architectures),
"pooling": ModelRegistry.is_pooling_model(architectures),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
supported_runner_types_lst: List[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks_lst: List[_ResolvedTask] = [
task for runner_type in supported_runner_types_lst
for task in _RUNNER_TASKS[runner_type]
]
supported_tasks = set(supported_tasks_lst)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"
if len(supported_tasks_lst) > 1:
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task is not None:
selected_task = preferred_task
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
# Aliases
if task_option == "embedding":
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task != "embed":
msg = ("The 'embedding' task will be restricted to "
"embedding models in a future release. Please "
"pass `--task classify`, `--task score`, or "
"`--task reward` explicitly for other pooling "
"models.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_option = preferred_task or "embed"
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
@@ -533,7 +578,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.task == "embedding":
if self.runner_type == "pooling":
self.use_async_output_proc = False
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
@@ -750,6 +795,14 @@ class ModelConfig:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_cross_encoder_model(architectures)
@property
def supported_runner_types(self) -> Set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[self.task]
class CacheConfig:
"""Configuration for the KV cache.
@@ -1096,7 +1149,7 @@ class ParallelConfig:
class SchedulerConfig:
"""Scheduler configuration."""
task: str = "generate" # The task to use the model for.
runner_type: str = "generate" # The runner type to launch for the model.
# Maximum number of tokens to be processed in a single iteration.
max_num_batched_tokens: int = field(default=None) # type: ignore
@@ -1164,11 +1217,11 @@ class SchedulerConfig:
# for higher throughput.
self.max_num_batched_tokens = max(self.max_model_len, 2048)
if self.task == "embedding":
# For embedding, choose specific value for higher throughput
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens