[Misc] Split up pooling tasks (#10820)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
137
vllm/config.py
137
vllm/config.py
@@ -45,13 +45,27 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||
|
||||
TaskOption = Literal["auto", "generate", "embedding"]
|
||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||
"score", "reward"]
|
||||
|
||||
# "draft" is only used internally for speculative decoding
|
||||
_Task = Literal["generate", "embedding", "draft"]
|
||||
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
|
||||
"draft"]
|
||||
|
||||
RunnerType = Literal["generate", "pooling", "draft"]
|
||||
|
||||
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
|
||||
"generate": ["generate"],
|
||||
"pooling": ["embed", "classify", "score", "reward"],
|
||||
"draft": ["draft"],
|
||||
}
|
||||
|
||||
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
|
||||
task: runner
|
||||
for runner, tasks in _RUNNER_TASKS.items() for task in tasks
|
||||
}
|
||||
|
||||
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
|
||||
PretrainedConfig]]
|
||||
@@ -144,7 +158,7 @@ class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
task: Union[TaskOption, _Task],
|
||||
task: Union[TaskOption, Literal["draft"]],
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
@@ -295,6 +309,7 @@ class ModelConfig:
|
||||
supported_tasks, task = self._resolve_task(task, self.hf_config)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task: Final = task
|
||||
|
||||
self.pooler_config = self._init_pooler_config(override_pooler_config)
|
||||
|
||||
self._verify_quantization()
|
||||
@@ -323,7 +338,7 @@ class ModelConfig:
|
||||
override_pooler_config: Optional["PoolerConfig"],
|
||||
) -> Optional["PoolerConfig"]:
|
||||
|
||||
if self.task == "embedding":
|
||||
if self.runner_type == "pooling":
|
||||
user_config = override_pooler_config or PoolerConfig()
|
||||
|
||||
base_config = get_pooling_config(self.model, self.revision)
|
||||
@@ -357,60 +372,90 @@ class ModelConfig:
|
||||
"either 'auto', 'slow' or 'mistral'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _get_preferred_task(
|
||||
self,
|
||||
architectures: List[str],
|
||||
supported_tasks: Set[_ResolvedTask],
|
||||
) -> Optional[_ResolvedTask]:
|
||||
model_id = self.model
|
||||
if get_pooling_config(model_id, self.revision):
|
||||
return "embed"
|
||||
if ModelRegistry.is_cross_encoder_model(architectures):
|
||||
return "score"
|
||||
|
||||
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
|
||||
# Other models follow this pattern
|
||||
("ForCausalLM", "generate"),
|
||||
("ForConditionalGeneration", "generate"),
|
||||
("ForSequenceClassification", "classify"),
|
||||
("ChatModel", "generate"),
|
||||
("LMHeadModel", "generate"),
|
||||
("EmbeddingModel", "embed"),
|
||||
("RewardModel", "reward"),
|
||||
]
|
||||
_, arch = ModelRegistry.inspect_model_cls(architectures)
|
||||
|
||||
for suffix, pref_task in suffix_to_preferred_task:
|
||||
if arch.endswith(suffix) and pref_task in supported_tasks:
|
||||
return pref_task
|
||||
|
||||
return None
|
||||
|
||||
def _resolve_task(
|
||||
self,
|
||||
task_option: Union[TaskOption, _Task],
|
||||
task_option: Union[TaskOption, Literal["draft"]],
|
||||
hf_config: PretrainedConfig,
|
||||
) -> Tuple[Set[_Task], _Task]:
|
||||
) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
|
||||
if task_option == "draft":
|
||||
return {"draft"}, "draft"
|
||||
|
||||
architectures = getattr(hf_config, "architectures", [])
|
||||
|
||||
task_support: Dict[_Task, bool] = {
|
||||
runner_support: Dict[RunnerType, bool] = {
|
||||
# NOTE: Listed from highest to lowest priority,
|
||||
# in case the model supports multiple of them
|
||||
"generate": ModelRegistry.is_text_generation_model(architectures),
|
||||
"embedding": ModelRegistry.is_pooling_model(architectures),
|
||||
"pooling": ModelRegistry.is_pooling_model(architectures),
|
||||
}
|
||||
supported_tasks_lst: List[_Task] = [
|
||||
task for task, is_supported in task_support.items() if is_supported
|
||||
supported_runner_types_lst: List[RunnerType] = [
|
||||
runner_type
|
||||
for runner_type, is_supported in runner_support.items()
|
||||
if is_supported
|
||||
]
|
||||
|
||||
supported_tasks_lst: List[_ResolvedTask] = [
|
||||
task for runner_type in supported_runner_types_lst
|
||||
for task in _RUNNER_TASKS[runner_type]
|
||||
]
|
||||
supported_tasks = set(supported_tasks_lst)
|
||||
|
||||
if task_option == "auto":
|
||||
selected_task = next(iter(supported_tasks_lst))
|
||||
|
||||
if len(supported_tasks) > 1:
|
||||
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
|
||||
# Hardcode the models that are exceptions
|
||||
("AquilaModel", "generate"),
|
||||
("ChatGLMModel", "generate"),
|
||||
# Other models follow this pattern
|
||||
("ForCausalLM", "generate"),
|
||||
("ForConditionalGeneration", "generate"),
|
||||
("ChatModel", "generate"),
|
||||
("LMHeadModel", "generate"),
|
||||
("EmbeddingModel", "embedding"),
|
||||
("RewardModel", "embedding"),
|
||||
("ForSequenceClassification", "embedding"),
|
||||
]
|
||||
info, arch = ModelRegistry.inspect_model_cls(architectures)
|
||||
|
||||
for suffix, pref_task in suffix_to_preferred_task:
|
||||
if arch.endswith(suffix) and pref_task in supported_tasks:
|
||||
selected_task = pref_task
|
||||
break
|
||||
else:
|
||||
if (arch.endswith("Model")
|
||||
and info.architecture.endswith("ForCausalLM")
|
||||
and "embedding" in supported_tasks):
|
||||
selected_task = "embedding"
|
||||
if len(supported_tasks_lst) > 1:
|
||||
preferred_task = self._get_preferred_task(
|
||||
architectures, supported_tasks)
|
||||
if preferred_task is not None:
|
||||
selected_task = preferred_task
|
||||
|
||||
logger.info(
|
||||
"This model supports multiple tasks: %s. "
|
||||
"Defaulting to '%s'.", supported_tasks, selected_task)
|
||||
else:
|
||||
# Aliases
|
||||
if task_option == "embedding":
|
||||
preferred_task = self._get_preferred_task(
|
||||
architectures, supported_tasks)
|
||||
if preferred_task != "embed":
|
||||
msg = ("The 'embedding' task will be restricted to "
|
||||
"embedding models in a future release. Please "
|
||||
"pass `--task classify`, `--task score`, or "
|
||||
"`--task reward` explicitly for other pooling "
|
||||
"models.")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
task_option = preferred_task or "embed"
|
||||
|
||||
if task_option not in supported_tasks:
|
||||
msg = (
|
||||
f"This model does not support the '{task_option}' task. "
|
||||
@@ -533,7 +578,7 @@ class ModelConfig:
|
||||
|
||||
# Async postprocessor is not necessary with embedding mode
|
||||
# since there is no token generation
|
||||
if self.task == "embedding":
|
||||
if self.runner_type == "pooling":
|
||||
self.use_async_output_proc = False
|
||||
|
||||
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
|
||||
@@ -750,6 +795,14 @@ class ModelConfig:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
return ModelRegistry.is_cross_encoder_model(architectures)
|
||||
|
||||
@property
|
||||
def supported_runner_types(self) -> Set[RunnerType]:
|
||||
return {_TASK_RUNNER[task] for task in self.supported_tasks}
|
||||
|
||||
@property
|
||||
def runner_type(self) -> RunnerType:
|
||||
return _TASK_RUNNER[self.task]
|
||||
|
||||
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache.
|
||||
@@ -1096,7 +1149,7 @@ class ParallelConfig:
|
||||
class SchedulerConfig:
|
||||
"""Scheduler configuration."""
|
||||
|
||||
task: str = "generate" # The task to use the model for.
|
||||
runner_type: str = "generate" # The runner type to launch for the model.
|
||||
|
||||
# Maximum number of tokens to be processed in a single iteration.
|
||||
max_num_batched_tokens: int = field(default=None) # type: ignore
|
||||
@@ -1164,11 +1217,11 @@ class SchedulerConfig:
|
||||
# for higher throughput.
|
||||
self.max_num_batched_tokens = max(self.max_model_len, 2048)
|
||||
|
||||
if self.task == "embedding":
|
||||
# For embedding, choose specific value for higher throughput
|
||||
if self.runner_type == "pooling":
|
||||
# Choose specific value for higher throughput
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_num_batched_tokens,
|
||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
if self.is_multimodal_model:
|
||||
# The value needs to be at least the number of multimodal tokens
|
||||
|
||||
@@ -337,7 +337,7 @@ class Scheduler:
|
||||
self.lora_config = lora_config
|
||||
|
||||
version = "selfattn"
|
||||
if (self.scheduler_config.task == "embedding"
|
||||
if (self.scheduler_config.runner_type == "pooling"
|
||||
or self.cache_config.is_attention_free):
|
||||
version = "placeholder"
|
||||
|
||||
|
||||
@@ -1066,7 +1066,7 @@ class EngineArgs:
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter
|
||||
and model_config.task != "embedding"):
|
||||
and model_config.runner_type != "pooling"):
|
||||
self.enable_chunked_prefill = True
|
||||
logger.warning(
|
||||
"Chunked prefill is enabled by default for models with "
|
||||
@@ -1083,7 +1083,8 @@ class EngineArgs:
|
||||
"errors during the initial memory profiling phase, or result "
|
||||
"in low performance due to small KV cache space. Consider "
|
||||
"setting --max-model-len to a smaller value.", max_model_len)
|
||||
elif self.enable_chunked_prefill and model_config.task == "embedding":
|
||||
elif (self.enable_chunked_prefill
|
||||
and model_config.runner_type == "pooling"):
|
||||
msg = "Chunked prefill is not supported for embedding models"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -1144,7 +1145,7 @@ class EngineArgs:
|
||||
" please file an issue with detailed information.")
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
task=model_config.task,
|
||||
runner_type=model_config.runner_type,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
max_model_len=model_config.max_model_len,
|
||||
|
||||
@@ -288,7 +288,7 @@ class LLMEngine:
|
||||
|
||||
self.model_executor = executor_class(vllm_config=vllm_config, )
|
||||
|
||||
if self.model_config.task != "embedding":
|
||||
if self.model_config.runner_type != "pooling":
|
||||
self._initialize_kv_caches()
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
@@ -1123,7 +1123,7 @@ class LLMEngine:
|
||||
seq_group.metrics.model_execute_time = (
|
||||
o.model_execute_time)
|
||||
|
||||
if self.model_config.task == "embedding":
|
||||
if self.model_config.runner_type == "pooling":
|
||||
self._process_sequence_group_outputs(seq_group, output)
|
||||
else:
|
||||
self.output_processor.process_prompt_logprob(seq_group, output)
|
||||
|
||||
@@ -381,19 +381,20 @@ class LLM:
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the ``inputs`` parameter.
|
||||
"""
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "generate":
|
||||
runner_type = self.llm_engine.model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
messages = [
|
||||
"LLM.generate() is only supported for (conditional) generation "
|
||||
"models (XForCausalLM, XForConditionalGeneration).",
|
||||
]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "generate" in supported_tasks:
|
||||
supported_runner_types = self.llm_engine.model_config \
|
||||
.supported_runner_types
|
||||
if "generate" in supported_runner_types:
|
||||
messages.append(
|
||||
"Your model supports the 'generate' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task generate`.")
|
||||
"Your model supports the 'generate' runner, but is "
|
||||
f"currently initialized for the '{runner_type}' runner. "
|
||||
"Please initialize vLLM using `--task generate`.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
@@ -793,16 +794,18 @@ class LLM:
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the ``inputs`` parameter.
|
||||
"""
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "embedding":
|
||||
messages = ["LLM.encode() is only supported for embedding models."]
|
||||
runner_type = self.llm_engine.model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
messages = ["LLM.encode() is only supported for pooling models."]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "embedding" in supported_tasks:
|
||||
supported_runner_types = self.llm_engine.model_config \
|
||||
.supported_runner_types
|
||||
if "pooling" in supported_runner_types:
|
||||
messages.append(
|
||||
"Your model supports the 'embedding' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task embedding`.")
|
||||
"Your model supports the 'pooling' runner, but is "
|
||||
f"currently initialized for the '{runner_type}' runner. "
|
||||
"Please initialize vLLM using `--task embed`, "
|
||||
"`--task classify`, `--task score` etc.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
@@ -864,21 +867,23 @@ class LLM:
|
||||
A list of ``PoolingRequestOutput`` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
task = self.llm_engine.model_config.task
|
||||
if task != "embedding":
|
||||
messages = ["LLM.score() is only supported for embedding models."]
|
||||
runner_type = self.llm_engine.model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
messages = ["LLM.score() is only supported for pooling models."]
|
||||
|
||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
||||
if "embedding" in supported_tasks:
|
||||
supported_runner_types = self.llm_engine.model_config \
|
||||
.supported_runner_types
|
||||
if "pooling" in supported_runner_types:
|
||||
messages.append(
|
||||
"Your model supports the 'embedding' task, but is "
|
||||
f"currently initialized for the '{task}' task. Please "
|
||||
"initialize the model using `--task embedding`.")
|
||||
"Your model supports the 'pooling' runner, but is "
|
||||
f"currently initialized for the '{runner_type}' runner. "
|
||||
"Please initialize vLLM using `--task embed`, "
|
||||
"`--task classify`, `--task score` etc.")
|
||||
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
if not self.llm_engine.model_config.is_cross_encoder:
|
||||
raise ValueError("Your model does not support the cross encoding")
|
||||
raise ValueError("Your model does not support cross encoding")
|
||||
|
||||
tokenizer = self.llm_engine.get_tokenizer()
|
||||
|
||||
|
||||
@@ -573,7 +573,7 @@ def init_app_state(
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.task == "generate" else None
|
||||
) if model_config.runner_type == "generate" else None
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
@@ -582,7 +582,7 @@ def init_app_state(
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
) if model_config.task == "generate" else None
|
||||
) if model_config.runner_type == "generate" else None
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
@@ -590,13 +590,13 @@ def init_app_state(
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
) if model_config.task == "embedding" else None
|
||||
) if model_config.runner_type == "pooling" else None
|
||||
state.openai_serving_scores = OpenAIServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger
|
||||
) if (model_config.task == "embedding" \
|
||||
) if (model_config.runner_type == "pooling" \
|
||||
and model_config.is_cross_encoder) else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
|
||||
@@ -224,7 +224,7 @@ async def main(args):
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.task == "generate" else None
|
||||
) if model_config.runner_type == "generate" else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
@@ -232,7 +232,7 @@ async def main(args):
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if model_config.task == "embedding" else None
|
||||
) if model_config.runner_type == "pooling" else None
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
@@ -35,7 +35,7 @@ def get_model_architecture(
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
|
||||
if model_config.task == "embedding":
|
||||
if model_config.runner_type == "pooling":
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
|
||||
return model_cls, arch
|
||||
|
||||
@@ -42,7 +42,7 @@ class EngineCore:
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
):
|
||||
assert vllm_config.model_config.task != "embedding"
|
||||
assert vllm_config.model_config.runner_type != "pooling"
|
||||
|
||||
logger.info("Initializing an LLM engine (v%s) with config: %s",
|
||||
VLLM_VERSION, vllm_config)
|
||||
|
||||
@@ -163,7 +163,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
else {"return_hidden_states": True}
|
||||
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
|
||||
if self.model_config.task == "embedding":
|
||||
if self.model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = CPUPoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = CPUEncoderDecoderModelRunner
|
||||
|
||||
@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
if model_config.task == "embedding":
|
||||
if model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = PoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = EncoderDecoderModelRunner
|
||||
|
||||
Reference in New Issue
Block a user