[Refactor] Move task outside of PoolingParams.verify (#33796)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -70,6 +70,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -82,6 +83,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_
@@ -63,6 +63,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -75,6 +76,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_
@@ -122,6 +122,7 @@ steps:
 - vllm/
 - tests/test_inputs.py
 - tests/test_outputs.py
+- tests/test_pooling_params.py
 - tests/multimodal
 - tests/renderers
 - tests/standalone_tests/lazy_imports.py
@@ -134,6 +135,7 @@ steps:
 - python3 standalone_tests/lazy_imports.py
 - pytest -v -s test_inputs.py
 - pytest -v -s test_outputs.py
+- pytest -v -s test_pooling_params.py
 - pytest -v -s -m 'cpu_test' multimodal
 - pytest -v -s renderers
 - pytest -v -s tokenizers_
@@ -469,6 +469,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
@@ -757,6 +757,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
@@ -138,17 +138,17 @@ def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
     """Test that ColBERT model does not support 'embed' task."""
+    task = "embed"
     text = "What is the capital of France?"
 
-    pooling_response = requests.post(
+    response = requests.post(
         server.url_for("pooling"),
         json={
             "model": model_name,
             "input": text,
-            "task": "embed",
+            "task": task,
         },
     )
 
-    # Should return error
-    assert pooling_response.status_code == 400
-    assert "Task embed is not supported" in pooling_response.text
+    assert response.json()["error"]["type"] == "BadRequestError"
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
@@ -232,6 +232,4 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
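The API-level tests above all converge on a single error contract for an unsupported pooling task. A minimal client-side sketch of what a rejected request now looks like; the URL, model name, and task value are illustrative assumptions, not values from this diff:

import requests

task = "classify"  # assume the served model does not support this task
response = requests.post(
    "http://localhost:8000/pooling",  # hypothetical server address
    json={"model": "my-pooling-model", "input": "hello", "task": task},
)

error = response.json()["error"]
assert error["type"] == "BadRequestError"
assert error["message"].startswith(f"Unsupported task: {task!r}")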
@@ -27,35 +27,24 @@ class MockModelConfig:
     pooler_config: PoolerConfig
 
 
-def test_task():
-    pooling_params = PoolingParams()
-    pooling_params.verify(task="score")
-
-    pooling_params = PoolingParams(task="score")
-    pooling_params.verify(task="score")
-
-    with pytest.raises(ValueError):
-        pooling_params.verify(task="classify")
-
-
 def test_embed():
     task = "embed"
     model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
 
-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)
 
     invalid_parameters = classify_parameters + step_pooling_parameters
     for p in set(invalid_parameters) - set(embed_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)
 
 
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@@ -63,7 +52,6 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
     task = "embed"
     model_config = ModelConfig(
         model_info.name,
-        task="auto",
        tokenizer=model_info.name,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -71,37 +59,39 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
         dtype="float16",
     )
 
-    pooling_params = PoolingParams(dimensions=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, dimensions=None)
+    pooling_params.verify(model_config)
 
     with pytest.raises(ValueError):
-        pooling_params = PoolingParams(dimensions=1)
-        pooling_params.verify(task=task, model_config=model_config)
+        pooling_params = PoolingParams(task=task, dimensions=1)
+        pooling_params.verify(model_config)
 
     if model_info.is_matryoshka:
         assert model_info.matryoshka_dimensions is not None
-        pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0])
-        pooling_params.verify(task=task, model_config=model_config)
+        pooling_params = PoolingParams(
+            task=task, dimensions=model_info.matryoshka_dimensions[0]
+        )
+        pooling_params.verify(model_config)
 
 
 @pytest.mark.parametrize("task", ["score", "classify"])
 def test_classify(task):
     model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
 
-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)
 
     invalid_parameters = embed_parameters + step_pooling_parameters
     for p in set(invalid_parameters) - set(classify_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)
 
 
 @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -111,14 +101,14 @@ def test_token_embed(pooling_type: str):
         pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )
 
-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)
 
     invalid_parameters = classify_parameters
     if pooling_type != "STEP":
@@ -126,8 +116,8 @@ def test_token_embed(pooling_type: str):
 
     for p in set(invalid_parameters) - set(embed_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)
 
 
 @pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
@@ -137,14 +127,14 @@ def test_token_classify(pooling_type: str):
         pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
     )
 
-    pooling_params = PoolingParams(use_activation=None)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=None)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=True)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=True)
+    pooling_params.verify(model_config)
 
-    pooling_params = PoolingParams(use_activation=False)
-    pooling_params.verify(task=task, model_config=model_config)
+    pooling_params = PoolingParams(task=task, use_activation=False)
+    pooling_params.verify(model_config)
 
     invalid_parameters = embed_parameters
     if pooling_type != "STEP":
@@ -152,5 +142,5 @@ def test_token_classify(pooling_type: str):
 
     for p in set(invalid_parameters) - set(classify_parameters):
         with pytest.raises(ValueError):
-            pooling_params = PoolingParams(**{p: True})
-            pooling_params.verify(task=task, model_config=model_config)
+            pooling_params = PoolingParams(task=task, **{p: True})
+            pooling_params.verify(model_config)
@@ -1135,11 +1135,12 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
         for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
+            if param.task is None:
+                param.task = pooling_task
+            elif param.task != pooling_task:
+                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
+                raise ValueError(msg)
 
         self._validate_and_add_requests(
             prompts=prompts,
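With `verify` no longer accepting a task, `LLM.encode` reconciles the request-level `pooling_task` with the task stored on each `PoolingParams` instead. A hedged sketch of the two call styles this enables; the model name is illustrative:

from vllm import LLM, PoolingParams

llm = LLM(model="BAAI/bge-base-en-v1.5")  # illustrative embedding model

# Task left unset: encode() fills it in from pooling_task.
outputs = llm.encode("some text", pooling_params=PoolingParams(), pooling_task="embed")

# Task set explicitly: it must agree with pooling_task, otherwise a
# ValueError("You cannot overwrite param.task=... with pooling_task=...!") is raised.
outputs = llm.encode("some text", pooling_params=PoolingParams(task="embed"), pooling_task="embed")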
@@ -1472,8 +1473,9 @@ class LLM:
 
         if pooling_params is None:
             pooling_params = PoolingParams(task="score")
+        elif pooling_params.task is None:
+            pooling_params.task = "score"
 
-        pooling_params.verify("score", model_config)
         pooling_params_list = list[PoolingParams]()
 
         prompts = list[PromptType]()
@@ -1836,6 +1838,7 @@ class LLM:
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             priority=priority,
+            supported_tasks=self.supported_tasks,
         )
 
         self.llm_engine.add_request(
@@ -68,7 +68,6 @@ def init_pooling_state(
         OpenAIServingPooling(
             engine_client,
             state.openai_serving_models,
-            supported_tasks=supported_tasks,
             request_logger=request_logger,
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
@@ -76,7 +75,7 @@ def init_pooling_state(
             log_error_stack=args.log_error_stack,
         )
         )
-        if any(task in POOLING_TASKS for task in supported_tasks)
+        if any(t in supported_tasks for t in POOLING_TASKS)
         else None
     )
     state.openai_serving_embedding = (
@@ -6,19 +6,15 @@ from typing import Annotated, Any
 
 from pydantic import Field, model_validator
 
-from vllm import PoolingParams
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
 )
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
-from vllm.logger import init_logger
 from vllm.renderers import ChatParams, merge_kwargs
 from vllm.utils import random_uuid
 from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
 
-logger = init_logger(__name__)
-
 
 class PoolingBasicRequestMixin(OpenAIBaseModel):
     # --8<-- [start:pooling-common-params]
@@ -185,20 +181,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
     )
     # --8<-- [end:embed-extra-params]
 
-    def to_pooling_params(self):
-        if self.normalize is not None:
-            logger.warning_once(
-                "`normalize` is deprecated and will be removed in v0.17. "
-                "Please pass `use_activation` instead."
-            )
-            self.use_activation = self.normalize
-
-        return PoolingParams(
-            dimensions=self.dimensions,
-            use_activation=self.use_activation,
-            truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
-        )
-
 
 class ClassifyRequestMixin(OpenAIBaseModel):
     # --8<-- [start:classify-extra-params]
@@ -208,9 +190,3 @@ class ClassifyRequestMixin(OpenAIBaseModel):
         "`None` uses the pooler's default, which is `True` in most cases.",
     )
     # --8<-- [end:classify-extra-params]
-
-    def to_pooling_params(self):
-        return PoolingParams(
-            use_activation=self.use_activation,
-            truncate_prompt_tokens=getattr(self, "truncate_prompt_tokens", None),
-        )
@@ -6,6 +6,7 @@ from typing import Any, TypeAlias
 
 from pydantic import Field
 
+from vllm import PoolingParams
 from vllm.config import ModelConfig
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
 from vllm.entrypoints.pooling.base.protocol import (
@@ -14,9 +15,12 @@ from vllm.entrypoints.pooling.base.protocol import (
     CompletionRequestMixin,
     PoolingBasicRequestMixin,
 )
+from vllm.logger import init_logger
 from vllm.renderers import TokenizeParams
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 class ClassificationCompletionRequest(
     PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
@@ -33,6 +37,13 @@ class ClassificationCompletionRequest(
         max_total_tokens_param="max_model_len",
     )
 
+    def to_pooling_params(self):
+        return PoolingParams(
+            task="classify",
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+
 
 class ClassificationChatRequest(
     PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
@@ -55,6 +66,13 @@ class ClassificationChatRequest(
         max_total_tokens_param="max_model_len",
     )
 
+    def to_pooling_params(self):
+        return PoolingParams(
+            task="classify",
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+
 
 ClassificationRequest: TypeAlias = (
     ClassificationCompletionRequest | ClassificationChatRequest
@@ -22,7 +22,6 @@ from vllm.entrypoints.pooling.classify.protocol import (
 )
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput
-from vllm.pooling_params import PoolingParams
 
 logger = init_logger(__name__)
 
@@ -159,18 +158,3 @@ class ServingClassification(OpenAIServing):
         )
 
         return await self.handle(ctx)  # type: ignore[return-value]
-
-    def _create_pooling_params(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> PoolingParams | ErrorResponse:
-        pooling_params = super()._create_pooling_params(ctx)
-        if isinstance(pooling_params, ErrorResponse):
-            return pooling_params
-
-        try:
-            pooling_params.verify("classify", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        return pooling_params
@@ -5,6 +5,7 @@ from typing import Any, TypeAlias
 
 from pydantic import Field
 
+from vllm import PoolingParams
 from vllm.config import ModelConfig
 from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
 from vllm.entrypoints.pooling.base.protocol import (
@@ -13,9 +14,12 @@ from vllm.entrypoints.pooling.base.protocol import (
     EmbedRequestMixin,
     PoolingBasicRequestMixin,
 )
+from vllm.logger import init_logger
 from vllm.renderers import TokenizeParams
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 def _get_max_total_output_tokens(
     model_config: ModelConfig,
@@ -55,6 +59,21 @@ class EmbeddingCompletionRequest(
         max_output_tokens_param="max_model_len - max_embed_len",
     )
 
+    def to_pooling_params(self):
+        if self.normalize is not None:
+            logger.warning_once(
+                "`normalize` is deprecated and will be removed in v0.17. "
+                "Please pass `use_activation` instead."
+            )
+            self.use_activation = self.normalize
+
+        return PoolingParams(
+            task="embed",
+            dimensions=self.dimensions,
+            use_activation=self.use_activation,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+        )
+
 
 class EmbeddingChatRequest(
     PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
@@ -82,6 +101,21 @@ class EmbeddingChatRequest(
         max_output_tokens_param="max_model_len - max_embed_len",
     )
 
+    def to_pooling_params(self):
+        if self.normalize is not None:
+            logger.warning_once(
+                "`normalize` is deprecated and will be removed in v0.17. "
+                "Please pass `use_activation` instead."
+            )
+            self.use_activation = self.normalize
+
+        return PoolingParams(
+            task="embed",
+            dimensions=self.dimensions,
+            use_activation=self.use_activation,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+        )
+
 
 EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
@@ -424,12 +424,6 @@ class OpenAIServingEmbedding(OpenAIServing):
         if isinstance(pooling_params, ErrorResponse):
             return pooling_params
 
-        # Verify and set the task for pooling params
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
         if ctx.engine_prompts is None:
             return self.create_error_response("Engine prompts not available")
 
@@ -463,8 +457,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return None
 
         except Exception as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def _collect_batch(
         self,
@@ -634,7 +627,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return None
 
         except Exception as e:
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def create_embedding(
         self,
@@ -661,18 +654,3 @@ class OpenAIServingEmbedding(OpenAIServing):
         )
 
         return await self.handle(ctx)  # type: ignore[return-value]
-
-    def _create_pooling_params(
-        self,
-        ctx: EmbeddingServeContext,
-    ) -> PoolingParams | ErrorResponse:
-        pooling_params = super()._create_pooling_params(ctx)
-        if isinstance(pooling_params, ErrorResponse):
-            return pooling_params
-
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        return pooling_params
@@ -53,6 +53,7 @@ class PoolingCompletionRequest(
             self.use_activation = self.normalize
 
         return PoolingParams(
+            task=self.task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
             dimensions=self.dimensions,
@@ -90,6 +91,7 @@ class PoolingChatRequest(
             self.use_activation = self.normalize
 
         return PoolingParams(
+            task=self.task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
             dimensions=self.dimensions,
@@ -104,7 +106,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
     task: PoolingTask = "plugin"
 
     def to_pooling_params(self):
-        return PoolingParams()
+        return PoolingParams(task=self.task)
 
 
 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
@@ -35,7 +35,6 @@ from vllm.entrypoints.pooling.utils import (
 )
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.tasks import PoolingTask, SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
 
@@ -48,7 +47,6 @@ class OpenAIServingPooling(OpenAIServing):
         engine_client: EngineClient,
         models: OpenAIServingModels,
         *,
-        supported_tasks: tuple[SupportedTask, ...],
         request_logger: RequestLogger | None,
         chat_template: str | None,
         chat_template_content_format: ChatTemplateContentFormatOption,
@@ -62,7 +60,6 @@ class OpenAIServingPooling(OpenAIServing):
             log_error_stack=log_error_stack,
         )
 
-        self.supported_tasks = supported_tasks
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
         self.trust_request_chat_template = trust_request_chat_template
@@ -152,32 +149,6 @@ class OpenAIServingPooling(OpenAIServing):
         else:
             pooling_params = request.to_pooling_params()
 
-            pooling_task: PoolingTask
-            if request.task is None:
-                if "token_embed" in self.supported_tasks:
-                    pooling_task = "token_embed"
-                elif "token_classify" in self.supported_tasks:
-                    pooling_task = "token_classify"
-                elif "plugin" in self.supported_tasks:
-                    pooling_task = "plugin"
-                else:
-                    return self.create_error_response(
-                        f"pooling_task must be one of {self.supported_tasks}."
-                    )
-            else:
-                pooling_task = request.task
-
-            if pooling_task not in self.supported_tasks:
-                return self.create_error_response(
-                    f"Task {pooling_task} is not supported, it"
-                    f" must be one of {self.supported_tasks}."
-                )
-
-            try:
-                pooling_params.verify(pooling_task, self.model_config)
-            except ValueError as e:
-                return self.create_error_response(str(e))
-
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
 
@@ -212,8 +183,7 @@ class OpenAIServingPooling(OpenAIServing):
 
             generators.append(generator)
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         result_generator = merge_async_iterators(*generators)
 
@@ -251,8 +221,7 @@ class OpenAIServingPooling(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
         return response
 
@@ -18,6 +18,7 @@ from vllm.entrypoints.pooling.score.utils import (
     ScoreInputs,
 )
 from vllm.renderers import TokenizeParams
+from vllm.tasks import PoolingTask
 from vllm.utils import random_uuid
 
 
@@ -40,8 +41,9 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
         max_total_tokens_param="max_model_len",
     )
 
-    def to_pooling_params(self):
+    def to_pooling_params(self, task: PoolingTask = "score"):
         return PoolingParams(
+            task=task,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             use_activation=self.use_activation,
         )
@@ -122,6 +124,13 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
         max_total_tokens_param="max_model_len",
     )
 
+    def to_pooling_params(self, task: PoolingTask = "score"):
+        return PoolingParams(
+            task=task,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=self.use_activation,
+        )
+
 
 class RerankDocument(BaseModel):
     text: str | None = None
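Because the score-side `to_pooling_params` now takes the target task as an argument, one request type can feed several pooling paths. A brief sketch, assuming `request` is a parsed `RerankRequest`:

pooling_params = request.to_pooling_params("score")        # cross-encoder scoring
pooling_params = request.to_pooling_params("token_embed")  # late-interaction reranking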
@@ -118,12 +118,7 @@ class ServingScores(OpenAIServing):
 
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
-        pooling_params = request.to_pooling_params()
-
-        try:
-            pooling_params.verify("embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        pooling_params = request.to_pooling_params("embed")
 
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -223,19 +218,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        from vllm import PoolingParams
-
-        pooling_params = PoolingParams(
-            task="token_embed",
-            truncate_prompt_tokens=request.truncate_prompt_tokens,
-            use_activation=request.use_activation,
-        )
-
-        try:
-            pooling_params.verify("token_embed", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        # Use token_embed task for late interaction models
+        pooling_params = request.to_pooling_params("token_embed")
 
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -358,12 +341,7 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        default_pooling_params = request.to_pooling_params()
-
-        try:
-            default_pooling_params.verify("score", self.model_config)
-        except ValueError as e:
-            return self.create_error_response(str(e))
+        default_pooling_params = request.to_pooling_params("score")
 
         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -497,8 +475,7 @@ class ServingScores(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     async def do_rerank(
         self, request: RerankRequest, raw_request: Request | None = None
@@ -542,8 +519,7 @@ class ServingScores(OpenAIServing):
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+            return self.create_error_response(e)
 
     def request_output_to_score_response(
         self,
@@ -72,15 +72,7 @@ class PoolingParams(
         """Returns a deep copy of the PoolingParams instance."""
         return deepcopy(self)
 
-    def verify(
-        self, task: PoolingTask, model_config: "ModelConfig | None" = None
-    ) -> None:
-        if self.task is None:
-            self.task = task
-        elif self.task != task:
-            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
-            raise ValueError(msg)
-
+    def verify(self, model_config: "ModelConfig") -> None:
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
@@ -167,7 +159,7 @@ class PoolingParams(
         if mds is not None:
             if self.dimensions not in mds:
                 raise ValueError(
-                    f'Model "{model_config.served_model_name}" '
+                    f"Model {model_config.served_model_name!r} "
                     f"only supports {str(mds)} matryoshka dimensions, "
                     f"use other output dimensions will "
                     f"lead to poor results."
@@ -179,7 +171,7 @@ class PoolingParams(
             if self.use_activation is None:
                 self.use_activation = True
         else:
-            raise ValueError(f"Unknown pooling task: {self.task}")
+            raise ValueError(f"Unknown pooling task: {self.task!r}")
 
     def _verify_valid_parameters(self):
         assert self.task is not None, "task must be set"
@@ -194,7 +186,7 @@ class PoolingParams(
 
         if invalid_parameters:
             raise ValueError(
-                f"Task {self.task} only supports {valid_parameters} "
+                f"Task {self.task!r} only supports {valid_parameters} "
                 f"parameters, does not support "
                 f"{invalid_parameters} parameters"
            )
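The new signature inverts the old calling convention: the task is fixed when the `PoolingParams` object is built, and `verify` only checks the parameters against the model. A before/after sketch, assuming `model_config` is a `vllm.config.ModelConfig` instance:

from vllm import PoolingParams

# Before this change: the task arrived at verification time.
#     params = PoolingParams()
#     params.verify(task="embed", model_config=model_config)

# After this change: the task travels with the params object.
params = PoolingParams(task="embed")
params.verify(model_config)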
@@ -269,7 +269,11 @@ class AsyncLLM(EngineClient):
         cancel_task_threadsafe(handler)
 
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return await self.engine_core.get_supported_tasks_async()
+        if not hasattr(self, "_supported_tasks"):
+            # Cache the result
+            self._supported_tasks = await self.engine_core.get_supported_tasks_async()
+
+        return self._supported_tasks
 
     async def add_request(
         self,
@@ -355,6 +359,7 @@ class AsyncLLM(EngineClient):
             trace_headers=trace_headers,
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            supported_tasks=await self.get_supported_tasks(),
         )
         prompt_text = get_prompt_text(prompt)
 
@@ -31,6 +31,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.renderers import BaseRenderer
 from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
+from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.tokenizers import TokenizerLike
 from vllm.tokenizers.mistral import MistralTokenizer
 from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
@@ -196,13 +197,41 @@ class InputProcessor:
     def _validate_params(
         self,
         params: SamplingParams | PoolingParams,
+        # TODO: Validate generation tasks as well once `supported_tasks`
+        # is passed to all `process_inputs` calls
+        supported_tasks: tuple[SupportedTask, ...] | None,
     ):
         """
         Validate supported SamplingParam.
         Should raise ValueError if unsupported for API Server.
         """
 
         if isinstance(params, PoolingParams):
+            if supported_tasks is None:
+                raise RuntimeError("`supported_tasks` must be passed for pooling")
+
+            supported_pooling_tasks = [
+                task for task in supported_tasks if task in POOLING_TASKS
+            ]
+
+            if params.task is None:
+                if not supported_pooling_tasks:
+                    raise ValueError("Pooling tasks are not supported")
+
+                if "token_embed" in supported_pooling_tasks:
+                    params.task = "token_embed"
+                elif "token_classify" in supported_pooling_tasks:
+                    params.task = "token_classify"
+                elif "plugin" in supported_pooling_tasks:
+                    params.task = "plugin"
+
+            if params.task not in supported_pooling_tasks:
+                raise ValueError(
+                    f"Unsupported task: {params.task!r} "
+                    f"Supported tasks: {supported_pooling_tasks}"
+                )
+
+            params.verify(self.model_config)
             return
 
         self._validate_logprobs(params)
@@ -498,10 +527,11 @@ class InputProcessor:
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        supported_tasks: tuple[SupportedTask, ...] | None = None,
         resumable: bool = False,
     ) -> EngineCoreRequest:
         self._validate_lora(lora_request)
-        self._validate_params(params)
+        self._validate_params(params, supported_tasks)
 
         parallel_config = self.vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size
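The resolution order for a pooling request without an explicit task is fixed: `token_embed`, then `token_classify`, then `plugin`. A standalone sketch of that logic; the `POOLING_TASKS` tuple here is an illustrative stand-in, not the real constant:

POOLING_TASKS = ("token_embed", "token_classify", "plugin", "embed", "classify", "score")

def resolve_pooling_task(requested: str | None, supported_tasks: tuple[str, ...]) -> str:
    supported = [t for t in supported_tasks if t in POOLING_TASKS]
    task = requested
    if task is None:
        if not supported:
            raise ValueError("Pooling tasks are not supported")
        # Same precedence as _validate_params above.
        for candidate in ("token_embed", "token_classify", "plugin"):
            if candidate in supported:
                task = candidate
                break
    if task not in supported:
        raise ValueError(f"Unsupported task: {task!r} Supported tasks: {supported}")
    return task

resolve_pooling_task(None, ("token_embed", "plugin"))  # -> "token_embed"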
@@ -201,7 +201,11 @@ class LLMEngine:
         return outputs
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return self.engine_core.get_supported_tasks()
+        if not hasattr(self, "_supported_tasks"):
+            # Cache the result
+            self._supported_tasks = self.engine_core.get_supported_tasks()
+
+        return self._supported_tasks
 
     def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""
@@ -245,6 +249,7 @@ class LLMEngine:
             tokenization_kwargs,
             trace_headers,
             priority,
+            supported_tasks=self.get_supported_tasks(),
         )
         prompt_text = get_prompt_text(prompt)
 
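The same memoization appears in both engines; the hand-rolled `hasattr` check is shared because the async variant cannot be a property. For the sync engine alone, `functools.cached_property` would be a more idiomatic spelling. A hedged alternative sketch, not what this diff does, with `engine_core` assumed to exist on the instance:

from functools import cached_property

class LLMEngine:
    @cached_property
    def _supported_tasks(self) -> tuple[str, ...]:
        # Computed once on first access, then cached on the instance.
        return self.engine_core.get_supported_tasks()  # engine_core assumed

    def get_supported_tasks(self) -> tuple[str, ...]:
        return self._supported_tasks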
@@ -5037,7 +5037,7 @@ class GPUModelRunner(
 
         model = cast(VllmModelForPooling, self.get_model())
         dummy_pooling_params = PoolingParams(task=task)
-        dummy_pooling_params.verify(task=task, model_config=self.model_config)
+        dummy_pooling_params.verify(self.model_config)
 
         to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)