[Refactor] Move task outside of PoolingParams.verify (#33796)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Author: Cyrus Leung
Date:   2026-02-05 17:33:11 +08:00
Committed by: GitHub
Parent: d2f4a71cd5
Commit: 038914b7c8

24 changed files with 186 additions and 216 deletions
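
In short: the pooling task now travels with the PoolingParams instance instead of being passed to verify(). A minimal before/after sketch (illustrative only; model_config stands in for a real ModelConfig):

    from vllm import PoolingParams

    # Before this commit: the task was supplied at verification time.
    #   params = PoolingParams()
    #   params.verify(task="embed", model_config=model_config)

    # After this commit: the task is bound at construction time and
    # verify() only takes the model config.
    params = PoolingParams(task="embed")
    params.verify(model_config)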

View File

@@ -70,6 +70,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -82,6 +83,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_

View File

@@ -63,6 +63,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -75,6 +76,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_

View File

@@ -122,6 +122,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -134,6 +135,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_

View File

@@ -469,6 +469,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")

View File

@@ -757,6 +757,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")

View File

@@ -138,17 +138,17 @@ def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
     """Test that ColBERT model does not support 'embed' task."""
+    task = "embed"
     text = "What is the capital of France?"
-    pooling_response = requests.post(
+    response = requests.post(
         server.url_for("pooling"),
         json={
             "model": model_name,
             "input": text,
-            "task": "embed",
+            "task": task,
         },
     )
-    # Should return error
-    assert pooling_response.status_code == 400
-    assert "Task embed is not supported" in pooling_response.text
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")

View File

@@ -232,6 +232,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")

View File

@@ -27,35 +27,24 @@ class MockModelConfig:
     pooler_config: PoolerConfig


-def test_task():
-    pooling_params = PoolingParams()
-    pooling_params.verify(task="score")
-
-    pooling_params = PoolingParams(task="score")
-    pooling_params.verify(task="score")
-
-    with pytest.raises(ValueError):
-        pooling_params.verify(task="classify")
-
-
 def test_embed():
     task = "embed"
     model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)

     invalid_parameters = classify_parameters + step_pooling_parameters
     for p in set(invalid_parameters) - set(embed_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)


 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@@ -63,7 +52,6 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
     task = "embed"
     model_config = ModelConfig(
         model_info.name,
-        task="auto",
         tokenizer=model_info.name,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -71,37 +59,39 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
         dtype="float16",
     )

-    pooling_params = PoolingParams(dimensions=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, dimensions=None)
+    pooling_params.verify(model_config)

     with pytest.raises(ValueError):
-        pooling_params = PoolingParams(dimensions=1)
-        pooling_params.verify(task=task, model_config=model_config)
+        pooling_params = PoolingParams(task=task, dimensions=1)
+        pooling_params.verify(model_config)

     if model_info.is_matryoshka:
         assert model_info.matryoshka_dimensions is not None
-        pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0])
-        pooling_params.verify(task=task, model_config=model_config)
+        pooling_params = PoolingParams(
+            task=task, dimensions=model_info.matryoshka_dimensions[0]
+        )
+        pooling_params.verify(model_config)


 @pytest.mark.parametrize("task", ["score", "classify"])
 def test_classify(task):
     model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)

     invalid_parameters = embed_parameters + step_pooling_parameters
     for p in set(invalid_parameters) - set(classify_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)


 @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -111,14 +101,14 @@ def test_token_embed(pooling_type: str):
         pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )

-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)

     invalid_parameters = classify_parameters
     if pooling_type != "STEP":
@@ -126,8 +116,8 @@ def test_token_embed(pooling_type: str):
     for p in set(invalid_parameters) - set(embed_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)


 @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -137,14 +127,14 @@ def test_token_classify(pooling_type: str):
         pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )

-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)

-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)

     invalid_parameters = embed_parameters
     if pooling_type != "STEP":
@@ -152,5 +142,5 @@ def test_token_classify(pooling_type: str):
     for p in set(invalid_parameters) - set(classify_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)

View File

@@ -1135,11 +1135,12 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()

-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
         for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
+            if param.task is None:
+                param.task = pooling_task
+            elif param.task != pooling_task:
+                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
+                raise ValueError(msg)

         self._validate_and_add_requests(
             prompts=prompts,
@@ -1472,8 +1473,9 @@ class LLM:
         if pooling_params is None:
             pooling_params = PoolingParams(task="score")
-
-        pooling_params.verify("score", model_config)
+        elif pooling_params.task is None:
+            pooling_params.task = "score"

         pooling_params_list = list[PoolingParams]()
         prompts = list[PromptType]()
@@ -1836,6 +1838,7 @@ class LLM:
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
+            supported_tasks=self.supported_tasks,
        )

        self.llm_engine.add_request(

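The visible change for LLM.encode callers: a PoolingParams whose task conflicts with the requested pooling_task is now rejected up front instead of inside verify(). A sketch of the guard in isolation (bind_task is a hypothetical helper mirroring the loop body added above):

    from vllm import PoolingParams

    def bind_task(param: PoolingParams, pooling_task: str) -> None:
        # Mirrors the loop body added to LLM.encode above.
        if param.task is None:
            param.task = pooling_task
        elif param.task != pooling_task:
            msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
            raise ValueError(msg)

    bind_task(PoolingParams(), "embed")  # fills in the missing task

    # A conflicting task now raises:
    # ValueError: You cannot overwrite param.task='classify' with pooling_task='embed'!
    bind_task(PoolingParams(task="classify"), "embed")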
View File

@@ -68,7 +68,6 @@ def init_pooling_state(
             OpenAIServingPooling(
                 engine_client,
                 state.openai_serving_models,
-                supported_tasks=supported_tasks,
                 request_logger=request_logger,
                 chat_template=resolved_chat_template,
                 chat_template_content_format=args.chat_template_content_format,
@@ -76,7 +75,7 @@ def init_pooling_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if any(task in POOLING_TASKS for task in supported_tasks)
+        if any(t in supported_tasks for t in POOLING_TASKS)
         else None
     )

     state.openai_serving_embedding = (

View File

@@ -6,19 +6,15 @@ from typing import Annotated, Any

 from pydantic import Field, model_validator

-from vllm import PoolingParams
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
 )
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
-from vllm.logger import init_logger
 from vllm.renderers import ChatParams, merge_kwargs
 from vllm.utils import random_uuid
 from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness

-logger = init_logger(__name__)
-

 class PoolingBasicRequestMixin(OpenAIBaseModel):
     # --8<-- [start:pooling-common-params]
@@ -185,20 +181,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
     )
     # --8<-- [end:embed-extra-params]

-    def to_pooling_params(self):
-        if self.normalize is not None:
-            logger.warning_once(
-                "`normalize` is deprecated and will be removed in v0.17. "
-                "Please pass `use_activation` instead."
-            )
-            self.use_activation = self.normalize
-
-        return PoolingParams(
-            dimensions=self.dimensions,
-            use_activation=self.use_activation,
-            truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
-        )

 class ClassifyRequestMixin(OpenAIBaseModel):
     # --8<-- [start:classify-extra-params]
@@ -208,9 +190,3 @@ class ClassifyRequestMixin(OpenAIBaseModel):
         "`None` uses the pooler's default, which is `True` in most cases.",
     )
     # --8<-- [end:classify-extra-params]
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            use_activation=self.use_activation,
-            truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
-        )

View File

@@ -6,6 +6,7 @@ from typing import Any, TypeAlias

 from pydantic import Field

+from vllm import PoolingParams
 from vllm.config import ModelConfig
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
 from vllm.entrypoints.pooling.base.protocol import (
@@ -14,9 +15,12 @@ from vllm.entrypoints.pooling.base.protocol import (
     CompletionRequestMixin,
     PoolingBasicRequestMixin,
 )
+from vllm.logger import init_logger
 from vllm.renderers import TokenizeParams
 from vllm.utils import random_uuid

+logger = init_logger(__name__)
+

 class ClassificationCompletionRequest(
     PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
@@ -33,6 +37,13 @@ class ClassificationCompletionRequest(
             max_total_tokens_param="max_model_len",
         )

+    def to_pooling_params(self):
+        return PoolingParams(
+            task="classify",
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+

 class ClassificationChatRequest(
     PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
@@ -55,6 +66,13 @@ class ClassificationChatRequest(
             max_total_tokens_param="max_model_len",
         )

+    def to_pooling_params(self):
+        return PoolingParams(
+            task="classify",
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+

 ClassificationRequest: TypeAlias = (
     ClassificationCompletionRequest | ClassificationChatRequest

View File

@@ -22,7 +22,6 @@ from vllm.entrypoints.pooling.classify.protocol import (
 )
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput
-from vllm.pooling_params import PoolingParams

 logger = init_logger(__name__)
@@ -159,18 +158,3 @@ class ServingClassification(OpenAIServing):
         )

         return await self.handle(ctx)  # type: ignore[return-value]
-
-    def _create_pooling_params(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> PoolingParams | ErrorResponse:
-        pooling_params = super()._create_pooling_params(ctx)
-        if isinstance(pooling_params, ErrorResponse):
-            return pooling_params
-
-        try:
-            pooling_params.verify("classify", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        return pooling_params

View File

@@ -5,6 +5,7 @@ from typing import Any, TypeAlias

 from pydantic import Field

+from vllm import PoolingParams
 from vllm.config import ModelConfig
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
 from vllm.entrypoints.pooling.base.protocol import (
@@ -13,9 +14,12 @@ from vllm.entrypoints.pooling.base.protocol import (
     EmbedRequestMixin,
     PoolingBasicRequestMixin,
 )
+from vllm.logger import init_logger
 from vllm.renderers import TokenizeParams
 from vllm.utils import random_uuid

+logger = init_logger(__name__)
+

 def _get_max_total_output_tokens(
     model_config: ModelConfig,
@@ -55,6 +59,21 @@ class EmbeddingCompletionRequest(
             max_output_tokens_param="max_model_len - max_embed_len",
         )

+    def to_pooling_params(self):
+        if self.normalize is not None:
+            logger.warning_once(
+                "`normalize` is deprecated and will be removed in v0.17. "
+                "Please pass `use_activation` instead."
+            )
+            self.use_activation = self.normalize
+
+        return PoolingParams(
+            task="embed",
+            dimensions=self.dimensions,
+            use_activation=self.use_activation,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+        )
+

 class EmbeddingChatRequest(
     PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
@@ -82,6 +101,21 @@ class EmbeddingChatRequest(
             max_output_tokens_param="max_model_len - max_embed_len",
         )

+    def to_pooling_params(self):
+        if self.normalize is not None:
+            logger.warning_once(
+                "`normalize` is deprecated and will be removed in v0.17. "
+                "Please pass `use_activation` instead."
+            )
+            self.use_activation = self.normalize
+
+        return PoolingParams(
+            task="embed",
+            dimensions=self.dimensions,
+            use_activation=self.use_activation,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+        )
+

 EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest

View File

@@ -424,12 +424,6 @@ class OpenAIServingEmbedding(OpenAIServing):
         if isinstance(pooling_params, ErrorResponse):
             return pooling_params

-        # Verify and set the task for pooling params
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
         if ctx.engine_prompts is None:
             return self.create_error_response("Engine prompts not available")
@@ -463,8 +457,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return None
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

     async def _collect_batch(
         self,
@@ -634,7 +627,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return None
         except Exception as e:
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

     async def create_embedding(
         self,
@@ -661,18 +654,3 @@ class OpenAIServingEmbedding(OpenAIServing):
         )

         return await self.handle(ctx)  # type: ignore[return-value]
-
-    def _create_pooling_params(
-        self,
-        ctx: EmbeddingServeContext,
-    ) -> PoolingParams | ErrorResponse:
-        pooling_params = super()._create_pooling_params(ctx)
-        if isinstance(pooling_params, ErrorResponse):
-            return pooling_params
-
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        return pooling_params

View File

@@ -53,6 +53,7 @@ class PoolingCompletionRequest(
             self.use_activation = self.normalize

         return PoolingParams(
+            task=self.task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
             dimensions=self.dimensions,
@@ -90,6 +91,7 @@ class PoolingChatRequest(
             self.use_activation = self.normalize

         return PoolingParams(
+            task=self.task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
             dimensions=self.dimensions,
@@ -104,7 +106,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
     task: PoolingTask = "plugin"

     def to_pooling_params(self):
-        return PoolingParams()
+        return PoolingParams(task=self.task)


 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):

View File

@@ -35,7 +35,6 @@ from vllm.entrypoints.pooling.utils import (
 )
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.tasks import PoolingTask, SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
@@ -48,7 +47,6 @@ class OpenAIServingPooling(OpenAIServing):
         engine_client: EngineClient,
         models: OpenAIServingModels,
         *,
-        supported_tasks: tuple[SupportedTask, ...],
         request_logger: RequestLogger | None,
         chat_template: str | None,
         chat_template_content_format: ChatTemplateContentFormatOption,
@@ -62,7 +60,6 @@ class OpenAIServingPooling(OpenAIServing):
             log_error_stack=log_error_stack,
         )

-        self.supported_tasks = supported_tasks
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
         self.trust_request_chat_template = trust_request_chat_template
@@ -152,32 +149,6 @@ class OpenAIServingPooling(OpenAIServing):
             else:
                 pooling_params = request.to_pooling_params()

-            pooling_task: PoolingTask
-            if request.task is None:
-                if "token_embed" in self.supported_tasks:
-                    pooling_task = "token_embed"
-                elif "token_classify" in self.supported_tasks:
-                    pooling_task = "token_classify"
-                elif "plugin" in self.supported_tasks:
-                    pooling_task = "plugin"
-                else:
-                    return self.create_error_response(
-                        f"pooling_task must be one of {self.supported_tasks}."
-                    )
-            else:
-                pooling_task = request.task
-
-            if pooling_task not in self.supported_tasks:
-                return self.create_error_response(
-                    f"Task {pooling_task} is not supported, it"
-                    f" must be one of {self.supported_tasks}."
-                )
-
-            try:
-                pooling_params.verify(pooling_task, self.model_config)
-            except ValueError as e:
-                return self.create_error_response(str(e))
-
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
@@ -212,8 +183,7 @@ class OpenAIServingPooling(OpenAIServing):
                 generators.append(generator)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

         result_generator = merge_async_iterators(*generators)
@@ -251,8 +221,7 @@ class OpenAIServingPooling(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

         return response

View File

@@ -18,6 +18,7 @@ from vllm.entrypoints.pooling.score.utils import (
     ScoreInputs,
 )
 from vllm.renderers import TokenizeParams
+from vllm.tasks import PoolingTask
 from vllm.utils import random_uuid
@@ -40,8 +41,9 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
             max_total_tokens_param="max_model_len",
         )

-    def to_pooling_params(self):
+    def to_pooling_params(self, task: PoolingTask = "score"):
         return PoolingParams(
+            task=task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
         )
@@ -122,6 +124,13 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
             max_total_tokens_param="max_model_len",
         )

+    def to_pooling_params(self, task: PoolingTask = "score"):
+        return PoolingParams(
+            task=task,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+

 class RerankDocument(BaseModel):
     text: str | None = None

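Both score request types now accept the task explicitly, which lets ServingScores (next file) reuse one request type across the embed, token_embed, and score code paths. Illustrative calls (request stands for a ScoreRequest or RerankRequest instance):

    params = request.to_pooling_params()               # defaults to task="score"
    params = request.to_pooling_params("token_embed")  # late-interaction rerank path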
View File

@@ -118,12 +118,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        pooling_params = request.to_pooling_params()
-
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        pooling_params = request.to_pooling_params("embed")

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -223,19 +218,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

         # Use token_embed task for late interaction models
-        from vllm import PoolingParams
-
-        pooling_params = PoolingParams(
-            task="token_embed",
-            truncate_prompt_tokens=request.truncate_prompt_tokens,
-            use_activation=request.use_activation,
-        )
-
-        try:
-            pooling_params.verify("token_embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        pooling_params = request.to_pooling_params("token_embed")

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -358,12 +341,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        default_pooling_params = request.to_pooling_params()
-
-        try:
-            default_pooling_params.verify("score", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        default_pooling_params = request.to_pooling_params("score")

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -497,8 +475,7 @@ class ServingScores(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

     async def do_rerank(
         self, request: RerankRequest, raw_request: Request | None = None
@@ -542,8 +519,7 @@ class ServingScores(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)

     def request_output_to_score_response(
         self,

View File

@@ -72,15 +72,7 @@ class PoolingParams(
         """Returns a deep copy of the PoolingParams instance."""
         return deepcopy(self)

-    def verify(
-        self, task: PoolingTask, model_config: "ModelConfig | None" = None
-    ) -> None:
-        if self.task is None:
-            self.task = task
-        elif self.task != task:
-            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
-            raise ValueError(msg)
-
+    def verify(self, model_config: "ModelConfig") -> None:
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
@@ -167,7 +159,7 @@ class PoolingParams(
             if mds is not None:
                 if self.dimensions not in mds:
                     raise ValueError(
-                        f'Model "{model_config.served_model_name}" '
+                        f"Model {model_config.served_model_name!r} "
                         f"only supports {str(mds)} matryoshka dimensions, "
                         f"use other output dimensions will "
                         f"lead to poor results."
@@ -179,7 +171,7 @@ class PoolingParams(
             if self.use_activation is None:
                 self.use_activation = True
         else:
-            raise ValueError(f"Unknown pooling task: {self.task}")
+            raise ValueError(f"Unknown pooling task: {self.task!r}")

     def _verify_valid_parameters(self):
         assert self.task is not None, "task must be set"
@@ -194,7 +186,7 @@ class PoolingParams(
         if invalid_parameters:
             raise ValueError(
-                f"Task {self.task} only supports {valid_parameters} "
+                f"Task {self.task!r} only supports {valid_parameters} "
                 f"parameters, does not support "
                 f"{invalid_parameters} parameters"
             )

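Since the task is fixed at construction, an unsupported parameter combination now fails directly in verify(model_config) with the reworded messages above. A hedged sketch (assumes model_config is a classification model's ModelConfig and that dimensions is an embed-only parameter, as the tests in tests/test_pooling_params.py suggest):

    from vllm import PoolingParams

    params = PoolingParams(task="classify", dimensions=16)
    try:
        params.verify(model_config)
    except ValueError as e:
        print(e)  # e.g. Task 'classify' only supports [...] parameters, ...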
View File

@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient):
             cancel_task_threadsafe(handler)

     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return await self.engine_core.get_supported_tasks_async()
+        if not hasattr(self, "_supported_tasks"):
+            # Cache the result
+            self._supported_tasks = await self.engine_core.get_supported_tasks_async()
+
+        return self._supported_tasks

     async def add_request(
         self,
@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient):
             trace_headers=trace_headers,
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            supported_tasks=await self.get_supported_tasks(),
         )

         prompt_text = get_prompt_text(prompt)

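AsyncLLM and LLMEngine (below) now memoize the supported-task tuple with the same hasattr idiom, so the engine core is queried at most once per instance. The pattern in isolation (EngineCoreStub is a stand-in for the real engine-core client):

    class EngineCoreStub:
        def get_supported_tasks(self) -> tuple[str, ...]:
            print("expensive engine-core call")
            return ("embed", "classify")

    class Engine:
        def __init__(self) -> None:
            self.engine_core = EngineCoreStub()

        def get_supported_tasks(self) -> tuple[str, ...]:
            if not hasattr(self, "_supported_tasks"):
                # First call hits the engine core; later calls reuse the cache.
                self._supported_tasks = self.engine_core.get_supported_tasks()
            return self._supported_tasks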
View File

@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.renderers import BaseRenderer
 from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
+from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.tokenizers import TokenizerLike
 from vllm.tokenizers.mistral import MistralTokenizer
 from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
@@ -196,13 +197,41 @@ class InputProcessor:
     def _validate_params(
         self,
         params: SamplingParams | PoolingParams,
+        # TODO: Validate generation tasks as well once `supported_tasks`
+        # is passed to all `process_inputs` calls
+        supported_tasks: tuple[SupportedTask, ...] | None,
     ):
         """
         Validate supported SamplingParam.
         Should raise ValueError if unsupported for API Server.
         """
         if isinstance(params, PoolingParams):
+            if supported_tasks is None:
+                raise RuntimeError("`supported_tasks` must be passed for pooling")
+
+            supported_pooling_tasks = [
+                task for task in supported_tasks if task in POOLING_TASKS
+            ]
+
+            if params.task is None:
+                if not supported_pooling_tasks:
+                    raise ValueError("Pooling tasks are not supported")
+
+                if "token_embed" in supported_pooling_tasks:
+                    params.task = "token_embed"
+                elif "token_classify" in supported_pooling_tasks:
+                    params.task = "token_classify"
+                elif "plugin" in supported_pooling_tasks:
+                    params.task = "plugin"
+
+            if params.task not in supported_pooling_tasks:
+                raise ValueError(
+                    f"Unsupported task: {params.task!r} "
+                    f"Supported tasks: {supported_pooling_tasks}"
+                )
+
+            params.verify(self.model_config)
             return

         self._validate_logprobs(params)
@@ -498,10 +527,11 @@ class InputProcessor:
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        supported_tasks: tuple[SupportedTask, ...] | None = None,
         resumable: bool = False,
     ) -> EngineCoreRequest:
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, supported_tasks)

         parallel_config = self.vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size

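Task defaulting and validation for pooling requests now live in this one place instead of being duplicated in OpenAIServingPooling. A standalone sketch of the resolution order implemented above (resolve_pooling_task is a hypothetical helper; supported_pooling_tasks is whatever subset of POOLING_TASKS the engine reports):

    def resolve_pooling_task(
        task: str | None, supported_pooling_tasks: list[str]
    ) -> str:
        if task is None:
            if not supported_pooling_tasks:
                raise ValueError("Pooling tasks are not supported")
            # Preference order from the diff: token-level tasks, then plugin.
            for candidate in ("token_embed", "token_classify", "plugin"):
                if candidate in supported_pooling_tasks:
                    task = candidate
                    break
        if task not in supported_pooling_tasks:
            raise ValueError(
                f"Unsupported task: {task!r} "
                f"Supported tasks: {supported_pooling_tasks}"
            )
        return task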
View File

@@ -201,7 +201,11 @@ class LLMEngine:
         return outputs

     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return self.engine_core.get_supported_tasks()
+        if not hasattr(self, "_supported_tasks"):
+            # Cache the result
+            self._supported_tasks = self.engine_core.get_supported_tasks()
+
+        return self._supported_tasks

     def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""
@@ -245,6 +249,7 @@ class LLMEngine:
             tokenization_kwargs,
             trace_headers,
             priority,
+            supported_tasks=self.get_supported_tasks(),
         )

         prompt_text = get_prompt_text(prompt)

View File

@@ -5037,7 +5037,7 @@ class GPUModelRunner(
         model = cast(VllmModelForPooling, self.get_model())

         dummy_pooling_params = PoolingParams(task=task)
-        dummy_pooling_params.verify(task=task, model_config=self.model_config)
+        dummy_pooling_params.verify(self.model_config)

         to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)