[Refactor] Move task outside of PoolingParams.verify (#33796)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Cyrus Leung
2026-02-05 17:33:11 +08:00
committed by GitHub
parent d2f4a71cd5
commit 038914b7c8
24 changed files with 186 additions and 216 deletions

View File

@@ -70,6 +70,7 @@ steps:
- vllm/ - vllm/
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal - tests/multimodal
- tests/renderers - tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
@@ -82,6 +83,7 @@ steps:
- python3 standalone_tests/lazy_imports.py - python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers - pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_

View File

@@ -63,6 +63,7 @@ steps:
- vllm/ - vllm/
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal - tests/multimodal
- tests/renderers - tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
@@ -75,6 +76,7 @@ steps:
- python3 standalone_tests/lazy_imports.py - python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers - pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_

View File

@@ -122,6 +122,7 @@ steps:
- vllm/ - vllm/
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/test_pooling_params.py
- tests/multimodal - tests/multimodal
- tests/renderers - tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
@@ -134,6 +135,7 @@ steps:
- python3 standalone_tests/lazy_imports.py - python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s test_pooling_params.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers - pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_

View File

@@ -469,6 +469,4 @@ async def test_pooling_not_supported(
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith( assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
f"Task {task} is not supported"
)

View File

@@ -757,6 +757,4 @@ async def test_pooling_not_supported(
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith( assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
f"Task {task} is not supported"
)

View File

@@ -138,17 +138,17 @@ def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str): def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
"""Test that ColBERT model does not support 'embed' task.""" """Test that ColBERT model does not support 'embed' task."""
task = "embed"
text = "What is the capital of France?" text = "What is the capital of France?"
pooling_response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": text, "input": text,
"task": "embed", "task": task,
}, },
) )
# Should return error assert response.json()["error"]["type"] == "BadRequestError"
assert pooling_response.status_code == 400 assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
assert "Task embed is not supported" in pooling_response.text

View File

@@ -232,6 +232,4 @@ async def test_pooling_not_supported(
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith( assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
f"Task {task} is not supported"
)

View File

@@ -27,35 +27,24 @@ class MockModelConfig:
pooler_config: PoolerConfig pooler_config: PoolerConfig
def test_task():
pooling_params = PoolingParams()
pooling_params.verify(task="score")
pooling_params = PoolingParams(task="score")
pooling_params.verify(task="score")
with pytest.raises(ValueError):
pooling_params.verify(task="classify")
def test_embed(): def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True) pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False) pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
invalid_parameters = classify_parameters + step_pooling_parameters invalid_parameters = classify_parameters + step_pooling_parameters
for p in set(invalid_parameters) - set(embed_parameters): for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@@ -63,7 +52,6 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
task = "embed" task = "embed"
model_config = ModelConfig( model_config = ModelConfig(
model_info.name, model_info.name,
task="auto",
tokenizer=model_info.name, tokenizer=model_info.name,
tokenizer_mode="auto", tokenizer_mode="auto",
trust_remote_code=False, trust_remote_code=False,
@@ -71,37 +59,39 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
dtype="float16", dtype="float16",
) )
pooling_params = PoolingParams(dimensions=None) pooling_params = PoolingParams(task=task, dimensions=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(dimensions=1) pooling_params = PoolingParams(task=task, dimensions=1)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
if model_info.is_matryoshka: if model_info.is_matryoshka:
assert model_info.matryoshka_dimensions is not None assert model_info.matryoshka_dimensions is not None
pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0]) pooling_params = PoolingParams(
pooling_params.verify(task=task, model_config=model_config) task=task, dimensions=model_info.matryoshka_dimensions[0]
)
pooling_params.verify(model_config)
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True) pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False) pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
invalid_parameters = embed_parameters + step_pooling_parameters invalid_parameters = embed_parameters + step_pooling_parameters
for p in set(invalid_parameters) - set(classify_parameters): for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -111,14 +101,14 @@ def test_token_embed(pooling_type: str):
pooler_config=PoolerConfig(tok_pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True) pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False) pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
invalid_parameters = classify_parameters invalid_parameters = classify_parameters
if pooling_type != "STEP": if pooling_type != "STEP":
@@ -126,8 +116,8 @@ def test_token_embed(pooling_type: str):
for p in set(invalid_parameters) - set(embed_parameters): for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -137,14 +127,14 @@ def test_token_classify(pooling_type: str):
pooler_config=PoolerConfig(tok_pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(task=task, use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=True) pooling_params = PoolingParams(task=task, use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
pooling_params = PoolingParams(use_activation=False) pooling_params = PoolingParams(task=task, use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)
invalid_parameters = embed_parameters invalid_parameters = embed_parameters
if pooling_type != "STEP": if pooling_type != "STEP":
@@ -152,5 +142,5 @@ def test_token_classify(pooling_type: str):
for p in set(invalid_parameters) - set(classify_parameters): for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(task=task, **{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(model_config)

View File

@@ -1135,11 +1135,12 @@ class LLM:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params): for param in as_iter(pooling_params):
param.verify(pooling_task, model_config) if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
self._validate_and_add_requests( self._validate_and_add_requests(
prompts=prompts, prompts=prompts,
@@ -1472,8 +1473,9 @@ class LLM:
if pooling_params is None: if pooling_params is None:
pooling_params = PoolingParams(task="score") pooling_params = PoolingParams(task="score")
elif pooling_params.task is None:
pooling_params.task = "score"
pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]() pooling_params_list = list[PoolingParams]()
prompts = list[PromptType]() prompts = list[PromptType]()
@@ -1836,6 +1838,7 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
supported_tasks=self.supported_tasks,
) )
self.llm_engine.add_request( self.llm_engine.add_request(

View File

@@ -68,7 +68,6 @@ def init_pooling_state(
OpenAIServingPooling( OpenAIServingPooling(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
@@ -76,7 +75,7 @@ def init_pooling_state(
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
) )
if any(task in POOLING_TASKS for task in supported_tasks) if any(t in supported_tasks for t in POOLING_TASKS)
else None else None
) )
state.openai_serving_embedding = ( state.openai_serving_embedding = (

View File

@@ -6,19 +6,15 @@ from typing import Annotated, Any
from pydantic import Field, model_validator from pydantic import Field, model_validator
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
) )
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.logger import init_logger
from vllm.renderers import ChatParams, merge_kwargs from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
logger = init_logger(__name__)
class PoolingBasicRequestMixin(OpenAIBaseModel): class PoolingBasicRequestMixin(OpenAIBaseModel):
# --8<-- [start:pooling-common-params] # --8<-- [start:pooling-common-params]
@@ -185,20 +181,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
) )
# --8<-- [end:embed-extra-params] # --8<-- [end:embed-extra-params]
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
)
class ClassifyRequestMixin(OpenAIBaseModel): class ClassifyRequestMixin(OpenAIBaseModel):
# --8<-- [start:classify-extra-params] # --8<-- [start:classify-extra-params]
@@ -208,9 +190,3 @@ class ClassifyRequestMixin(OpenAIBaseModel):
"`None` uses the pooler's default, which is `True` in most cases.", "`None` uses the pooler's default, which is `True` in most cases.",
) )
# --8<-- [end:classify-extra-params] # --8<-- [end:classify-extra-params]
def to_pooling_params(self):
return PoolingParams(
use_activation=self.use_activation,
truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
)

View File

@@ -6,6 +6,7 @@ from typing import Any, TypeAlias
from pydantic import Field from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
@@ -14,9 +15,12 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin, CompletionRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__)
class ClassificationCompletionRequest( class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
@@ -33,6 +37,13 @@ class ClassificationCompletionRequest(
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
class ClassificationChatRequest( class ClassificationChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
@@ -55,6 +66,13 @@ class ClassificationChatRequest(
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self):
return PoolingParams(
task="classify",
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
ClassificationRequest: TypeAlias = ( ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest ClassificationCompletionRequest | ClassificationChatRequest

View File

@@ -22,7 +22,6 @@ from vllm.entrypoints.pooling.classify.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -159,18 +158,3 @@ class ServingClassification(OpenAIServing):
) )
return await self.handle(ctx) # type: ignore[return-value] return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("classify", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params

View File

@@ -5,6 +5,7 @@ from typing import Any, TypeAlias
from pydantic import Field from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import ( from vllm.entrypoints.pooling.base.protocol import (
@@ -13,9 +14,12 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin, EmbedRequestMixin,
PoolingBasicRequestMixin, PoolingBasicRequestMixin,
) )
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__)
def _get_max_total_output_tokens( def _get_max_total_output_tokens(
model_config: ModelConfig, model_config: ModelConfig,
@@ -55,6 +59,21 @@ class EmbeddingCompletionRequest(
max_output_tokens_param="max_model_len - max_embed_len", max_output_tokens_param="max_model_len - max_embed_len",
) )
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
class EmbeddingChatRequest( class EmbeddingChatRequest(
PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
@@ -82,6 +101,21 @@ class EmbeddingChatRequest(
max_output_tokens_param="max_model_len - max_embed_len", max_output_tokens_param="max_model_len - max_embed_len",
) )
def to_pooling_params(self):
if self.normalize is not None:
logger.warning_once(
"`normalize` is deprecated and will be removed in v0.17. "
"Please pass `use_activation` instead."
)
self.use_activation = self.normalize
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest

View File

@@ -424,12 +424,6 @@ class OpenAIServingEmbedding(OpenAIServing):
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
return pooling_params return pooling_params
# Verify and set the task for pooling params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
@@ -463,8 +457,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return None return None
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
async def _collect_batch( async def _collect_batch(
self, self,
@@ -634,7 +627,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(e)
async def create_embedding( async def create_embedding(
self, self,
@@ -661,18 +654,3 @@ class OpenAIServingEmbedding(OpenAIServing):
) )
return await self.handle(ctx) # type: ignore[return-value] return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params

View File

@@ -53,6 +53,7 @@ class PoolingCompletionRequest(
self.use_activation = self.normalize self.use_activation = self.normalize
return PoolingParams( return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation, use_activation=self.use_activation,
dimensions=self.dimensions, dimensions=self.dimensions,
@@ -90,6 +91,7 @@ class PoolingChatRequest(
self.use_activation = self.normalize self.use_activation = self.normalize
return PoolingParams( return PoolingParams(
task=self.task,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation, use_activation=self.use_activation,
dimensions=self.dimensions, dimensions=self.dimensions,
@@ -104,7 +106,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
task: PoolingTask = "plugin" task: PoolingTask = "plugin"
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams() return PoolingParams(task=self.task)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):

View File

@@ -35,7 +35,6 @@ from vllm.entrypoints.pooling.utils import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
@@ -48,7 +47,6 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
@@ -62,7 +60,6 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
self.supported_tasks = supported_tasks
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
@@ -152,32 +149,6 @@ class OpenAIServingPooling(OpenAIServing):
else: else:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
pooling_task: PoolingTask
if request.task is None:
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
elif "plugin" in self.supported_tasks:
pooling_task = "plugin"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
else:
pooling_task = request.task
if pooling_task not in self.supported_tasks:
return self.create_error_response(
f"Task {pooling_task} is not supported, it"
f" must be one of {self.supported_tasks}."
)
try:
pooling_params.verify(pooling_task, self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
@@ -212,8 +183,7 @@ class OpenAIServingPooling(OpenAIServing):
generators.append(generator) generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
@@ -251,8 +221,7 @@ class OpenAIServingPooling(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
return response return response

View File

@@ -18,6 +18,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs, ScoreInputs,
) )
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid from vllm.utils import random_uuid
@@ -40,8 +41,9 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self): def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams( return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation, use_activation=self.use_activation,
) )
@@ -122,6 +124,13 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
class RerankDocument(BaseModel): class RerankDocument(BaseModel):
text: str | None = None text: str | None = None

View File

@@ -118,12 +118,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params("embed")
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
@@ -223,19 +218,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
# Use token_embed task for late interaction models pooling_params = request.to_pooling_params("token_embed")
from vllm import PoolingParams
pooling_params = PoolingParams(
task="token_embed",
truncate_prompt_tokens=request.truncate_prompt_tokens,
use_activation=request.use_activation,
)
try:
pooling_params.verify("token_embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
@@ -358,12 +341,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params() default_pooling_params = request.to_pooling_params("score")
try:
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
@@ -497,8 +475,7 @@ class ServingScores(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
async def do_rerank( async def do_rerank(
self, request: RerankRequest, raw_request: Request | None = None self, request: RerankRequest, raw_request: Request | None = None
@@ -542,8 +519,7 @@ class ServingScores(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
def request_output_to_score_response( def request_output_to_score_response(
self, self,

View File

@@ -72,15 +72,7 @@ class PoolingParams(
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
return deepcopy(self) return deepcopy(self)
def verify( def verify(self, model_config: "ModelConfig") -> None:
self, task: PoolingTask, model_config: "ModelConfig | None" = None
) -> None:
if self.task is None:
self.task = task
elif self.task != task:
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)
# plugin task uses io_processor.parse_request to verify inputs, # plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify # skipping PoolingParams verify
if self.task == "plugin": if self.task == "plugin":
@@ -167,7 +159,7 @@ class PoolingParams(
if mds is not None: if mds is not None:
if self.dimensions not in mds: if self.dimensions not in mds:
raise ValueError( raise ValueError(
f'Model "{model_config.served_model_name}" ' f"Model {model_config.served_model_name!r} "
f"only supports {str(mds)} matryoshka dimensions, " f"only supports {str(mds)} matryoshka dimensions, "
f"use other output dimensions will " f"use other output dimensions will "
f"lead to poor results." f"lead to poor results."
@@ -179,7 +171,7 @@ class PoolingParams(
if self.use_activation is None: if self.use_activation is None:
self.use_activation = True self.use_activation = True
else: else:
raise ValueError(f"Unknown pooling task: {self.task}") raise ValueError(f"Unknown pooling task: {self.task!r}")
def _verify_valid_parameters(self): def _verify_valid_parameters(self):
assert self.task is not None, "task must be set" assert self.task is not None, "task must be set"
@@ -194,7 +186,7 @@ class PoolingParams(
if invalid_parameters: if invalid_parameters:
raise ValueError( raise ValueError(
f"Task {self.task} only supports {valid_parameters} " f"Task {self.task!r} only supports {valid_parameters} "
f"parameters, does not support " f"parameters, does not support "
f"{invalid_parameters} parameters" f"{invalid_parameters} parameters"
) )

View File

@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient):
cancel_task_threadsafe(handler) cancel_task_threadsafe(handler)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async() if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = await self.engine_core.get_supported_tasks_async()
return self._supported_tasks
async def add_request( async def add_request(
self, self,
@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
priority=priority, priority=priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text = get_prompt_text(prompt)

View File

@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
@@ -196,13 +197,41 @@ class InputProcessor:
def _validate_params( def _validate_params(
self, self,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
# TODO: Validate generation tasks as well once `supported_tasks`
# is passed to all `process_inputs` calls
supported_tasks: tuple[SupportedTask, ...] | None,
): ):
""" """
Validate supported SamplingParam. Validate supported SamplingParam.
Should raise ValueError if unsupported for API Server. Should raise ValueError if unsupported for API Server.
""" """
if isinstance(params, PoolingParams): if isinstance(params, PoolingParams):
if supported_tasks is None:
raise RuntimeError("`supported_tasks` must be passed for pooling")
supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS
]
if params.task is None:
if not supported_pooling_tasks:
raise ValueError("Pooling tasks are not supported")
if "token_embed" in supported_pooling_tasks:
params.task = "token_embed"
elif "token_classify" in supported_pooling_tasks:
params.task = "token_classify"
elif "plugin" in supported_pooling_tasks:
params.task = "plugin"
if params.task not in supported_pooling_tasks:
raise ValueError(
f"Unsupported task: {params.task!r} "
f"Supported tasks: {supported_pooling_tasks}"
)
params.verify(self.model_config)
return return
self._validate_logprobs(params) self._validate_logprobs(params)
@@ -498,10 +527,11 @@ class InputProcessor:
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
data_parallel_rank: int | None = None, data_parallel_rank: int | None = None,
supported_tasks: tuple[SupportedTask, ...] | None = None,
resumable: bool = False, resumable: bool = False,
) -> EngineCoreRequest: ) -> EngineCoreRequest:
self._validate_lora(lora_request) self._validate_lora(lora_request)
self._validate_params(params) self._validate_params(params, supported_tasks)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size

View File

@@ -201,7 +201,11 @@ class LLMEngine:
return outputs return outputs
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks() if not hasattr(self, "_supported_tasks"):
# Cache the result
self._supported_tasks = self.engine_core.get_supported_tasks()
return self._supported_tasks
def abort_request(self, request_ids: list[str], internal: bool = False) -> None: def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
"""Remove request_ids from EngineCore and Detokenizer.""" """Remove request_ids from EngineCore and Detokenizer."""
@@ -245,6 +249,7 @@ class LLMEngine:
tokenization_kwargs, tokenization_kwargs,
trace_headers, trace_headers,
priority, priority,
supported_tasks=self.get_supported_tasks(),
) )
prompt_text = get_prompt_text(prompt) prompt_text = get_prompt_text(prompt)

View File

@@ -5037,7 +5037,7 @@ class GPUModelRunner(
model = cast(VllmModelForPooling, self.get_model()) model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task) dummy_pooling_params = PoolingParams(task=task)
dummy_pooling_params.verify(task=task, model_config=self.model_config) dummy_pooling_params.verify(self.model_config)
to_update = model.pooler.get_pooling_updates(task) to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params) to_update.apply(dummy_pooling_params)