[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
wang.yuqi
2026-04-09 18:09:45 +08:00
committed by GitHub
parent b6c9be509e
commit 66c079ae83
43 changed files with 554 additions and 733 deletions

View File

@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -123,7 +122,6 @@ async def test_chat_error_non_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -173,7 +171,6 @@ async def test_chat_error_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)

View File

@@ -567,7 +567,6 @@ def _build_serving_render(
return OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=model_registry,
request_logger=None,
chat_template=CHAT_TEMPLATE,
@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig)
input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
renderer: MagicMock = field(default_factory=MagicMock)
@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Mock the generate method to return an async generator
@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony:
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine
@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(

View File

@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -107,7 +106,6 @@ async def test_completion_error_non_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
@@ -157,7 +155,6 @@ async def test_completion_error_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)

View File

@@ -137,7 +137,6 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
@@ -148,7 +147,6 @@ def mock_serving_setup():
serving_render = OpenAIServingRender(
model_config=mock_engine.model_config,
renderer=mock_engine.renderer,
io_processor=mock_engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,

View File

@@ -77,7 +77,6 @@ def _create_mock_engine():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
# renderer is accessed by OpenAIServing.__init__ and serving.py
mock_renderer = MagicMock()

View File

@@ -218,7 +218,6 @@ class TestInitializeToolSessions:
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
@@ -307,7 +306,6 @@ class TestValidateGeneratorInput:
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
tokenizer = FakeTokenizer()
@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning():
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()

View File

@@ -110,8 +110,11 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["embed", "token_embed"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -469,4 +469,8 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM):
)
@pytest.mark.parametrize("task", ["token_classify", "classify"])
@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -767,4 +767,8 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -65,12 +65,15 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "classify":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"

View File

@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
@@ -64,7 +64,8 @@ async def test_pooling_not_supported(
},
)
if task != "classify":
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -62,12 +62,15 @@ def test_token_ids_prompts(llm: LLM):
assert outputs[0].outputs.data.shape == (11, 384)
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "embed":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"

View File

@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
@@ -87,7 +87,8 @@ async def test_pooling_not_supported(
},
)
if task != "embed":
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock:
engine.errored = False
engine.model_config = MockModelConfig()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
engine.renderer = _build_renderer(engine.model_config)
return engine

View File

@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config.max_model_len = 2048
mock_engine_client.model_config = mock_model_config
mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
mock_engine_client.renderer = MagicMock()
serving_models = OpenAIServingModels(

View File

@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument():
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,

View File

@@ -14,7 +14,6 @@ from vllm.distributed.weight_transfer.base import (
from vllm.inputs import EngineInput, PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.sampling_params import SamplingParams
@@ -44,7 +43,6 @@ class EngineClient(ABC):
vllm_config: VllmConfig
model_config: ModelConfig
renderer: BaseRenderer
io_processor: IOProcessor | None
input_processor: InputProcessor
@property

View File

@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import (
load_chat_template,
)
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.scoring.io_processor import (
ScoringIOProcessor,
)
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessor
from vllm.entrypoints.pooling.scoring.typing import ScoreInput
from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
from vllm.entrypoints.utils import log_non_default_args
@@ -398,12 +396,11 @@ class LLM:
self.runner_type = self.model_config.runner_type
self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
model_config=self.model_config,
vllm_config=self.llm_engine.vllm_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
@@ -1081,118 +1078,55 @@ class LLM:
pooled hidden states in the same order as the input prompts.
"""
if isinstance(prompts, dict) and "data" in prompts and pooling_task != "plugin":
raise ValueError(
"The 'data' field is only supported for the 'plugin' pooling task."
)
self._verify_pooling_task(pooling_task)
assert pooling_task is not None and pooling_task in self.pooling_io_processors
if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
io_processor = self.pooling_io_processors[pooling_task]
# Validate the request data is valid for the loaded plugin
prompt_data = prompts.get("data")
if prompt_data is None:
raise ValueError(
"The 'data' field of the prompt is expected to contain "
"the prompt data and it cannot be None. "
"Refer to the documentation of the IOProcessor "
"in use for more details."
)
validated_prompt = self.io_processor.parse_data(prompt_data)
if pooling_params is None:
pooling_params = PoolingParams()
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
prompts_seq = prompt_to_seq(prompts)
params_seq: Sequence[PoolingParams] = [
self.io_processor.merge_pooling_params(param)
for param in self._params_to_seq(
pooling_params,
len(prompts_seq),
)
]
for p in params_seq:
if p.task is None:
p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
ctx = OfflineInputsContext(
prompts=prompts,
pooling_params=pooling_params,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
engine_inputs = io_processor.pre_process_offline(ctx)
n_inputs = len(engine_inputs)
assert ctx.pooling_params is not None
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
prompts_seq = prompt_to_seq(prompts)
params_seq = self._params_to_seq(pooling_params, len(prompts_seq))
params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)
for param in params_seq:
if param.task is None:
param.task = pooling_task
elif pooling_task == "plugin":
# `plugin` task uses io_processor.parse_request to verify inputs.
# We actually allow plugin to overwrite pooling_task.
pass
elif param.task != pooling_task:
msg = (
f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
)
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
if pooling_task in self.pooling_io_processors:
io_processor = self.pooling_io_processors[pooling_task]
processor_inputs = io_processor.pre_process_offline(
ctx=OfflineInputsContext(
prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs
)
)
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
seq_priority = self._priority_to_seq(None, n_inputs)
self._render_and_add_requests(
prompts=processor_inputs,
prompts=engine_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)
outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
outputs = io_processor.post_process_offline(
ctx=OfflineOutputsContext(outputs=outputs)
)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return outputs
def _verify_pooling_task(self, pooling_task: PoolingTask | None):
@@ -1254,6 +1188,14 @@ class LLM:
pooling_task,
)
if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
def embed(
self,
prompts: PromptType | Sequence[PromptType],
@@ -1458,6 +1400,9 @@ class LLM:
scoring_data = io_processor.valid_inputs(data_1, data_2)
n_queries = len(scoring_data.data_1)
if pooling_params is None:
pooling_params = PoolingParams()
ctx = OfflineInputsContext(
prompts=scoring_data,
pooling_params=pooling_params,
@@ -1466,15 +1411,11 @@ class LLM:
n_queries=n_queries,
)
processor_inputs = io_processor.pre_process_offline(ctx)
engine_inputs = io_processor.pre_process_offline(ctx)
n_inputs = len(engine_inputs)
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(processor_inputs)
)
if ctx.pooling_params is None:
ctx.pooling_params = PoolingParams()
params_seq = self._params_to_seq(ctx.pooling_params, len(processor_inputs))
seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)
for param in params_seq:
if param.task is None:
@@ -1483,10 +1424,10 @@ class LLM:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
seq_priority = self._priority_to_seq(None, len(processor_inputs))
seq_priority = self._priority_to_seq(None, n_inputs)
self._render_and_add_requests(
prompts=processor_inputs,
prompts=engine_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
@@ -1579,7 +1520,7 @@ class LLM:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({params}) "
f"The lengths of prompts ({num_requests}) "
f"and params ({len(params)}) must be the same."
)
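For orientation, a minimal caller-side sketch of the refactored offline flow above. The model name is a placeholder and the requested task is assumed to be among the model's supported pooling tasks; the error behaviour mirrors the checks added in this hunk.

from vllm import LLM

# Placeholder checkpoint; any pooling model whose supported tasks include
# "token_embed" works here.
llm = LLM(model="my-embedding-model")

# Plain prompts are dispatched to the PoolingIOProcessor registered for the
# requested task in llm.pooling_io_processors, which renders them into engine
# inputs before _run_engine collects the PoolingRequestOutputs.
outputs = llm.encode("vLLM pooling example", pooling_task="token_embed", use_tqdm=False)
print(outputs[0].outputs.data.shape)  # (num_tokens, hidden_size)

# Prompts carrying a "data" field are reserved for the "plugin" task, and
# requesting "plugin" without an installed IOProcessor plugin now fails fast
# in _verify_pooling_task with "No IOProcessor plugin installed."
try:
    llm.encode({"data": {"payload": 1}}, pooling_task="plugin", use_tqdm=False)
except ValueError as err:
    print(err)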

View File

@@ -370,7 +370,6 @@ async def init_app_state(
state.openai_serving_render = OpenAIServingRender(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
@@ -441,13 +440,12 @@ async def init_render_app_state(
Unlike :func:`init_app_state` this function does not require an
:class:`~vllm.engine.protocol.EngineClient`; it bootstraps the
preprocessing pipeline (renderer, io_processor, input_processor)
preprocessing pipeline (renderer, input_processor)
directly from the :class:`~vllm.config.VllmConfig`.
"""
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers import renderer_from_config
served_model_names = args.served_model_name or [args.model]
@@ -465,15 +463,11 @@ async def init_render_app_state(
request_logger = None
renderer = renderer_from_config(vllm_config)
io_processor = get_io_processor(
vllm_config, renderer, vllm_config.model_config.io_processor_plugin
)
resolved_chat_template = load_chat_template(args.chat_template)
state.openai_serving_render = OpenAIServingRender(
model_config=vllm_config.model_config,
renderer=renderer,
io_processor=io_processor,
model_registry=model_registry,
request_logger=request_logger,
chat_template=resolved_chat_template,

View File

@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
@@ -78,10 +71,7 @@ from vllm.tracing import (
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
collect_from_async_generator,
merge_async_iterators,
)
from vllm.utils.async_utils import collect_from_async_generator
logger = init_logger(__name__)
@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol):
CompletionLikeRequest: TypeAlias = (
CompletionRequest
| TokenizeCompletionRequest
| DetokenizeRequest
| PoolingCompletionRequest
CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
)
ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| BatchChatCompletionRequest
| TokenizeChatRequest
| PoolingChatRequest
ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest
)
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = (
| ChatLikeRequest
| SpeechToTextRequest
| ResponsesRequest
| IOProcessorRequest
| GenerateRequest
)
@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = (
| ChatCompletionResponse
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| GenerateResponse
)
@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]):
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_inputs: list[EngineInput] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -171,7 +147,6 @@ class OpenAIServing:
super().__init__()
self.engine_client = engine_client
self.models = models
self.request_logger = request_logger
@@ -179,7 +154,6 @@ class OpenAIServing:
self.model_config = engine_client.model_config
self.renderer = engine_client.renderer
self.io_processor = engine_client.io_processor
self.input_processor = engine_client.input_processor
async def beam_search(
@@ -381,155 +355,6 @@ class OpenAIServing:
prompt_logprobs=None,
)
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override to prepare `ctx`.
"""
return None
def _build_response(
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return self.create_error_response("unimplemented endpoint")
async def handle(
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
async for response in self._pipeline(ctx):
return response
return self.create_error_response("No response yielded from pipeline")
async def _pipeline(
self,
ctx: ServeContext,
) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
"""Execute the request processing pipeline yielding responses."""
if error := await self._check_model(ctx.request):
yield error
if error := self._validate_request(ctx):
yield error
preprocess_ret = await self._preprocess(ctx)
if isinstance(preprocess_ret, ErrorResponse):
yield preprocess_ret
generators_ret = await self._prepare_generators(ctx)
if isinstance(generators_ret, ErrorResponse):
yield generators_ret
collect_ret = await self._collect_batch(ctx)
if isinstance(collect_ret, ErrorResponse):
yield collect_ret
yield self._build_response(ctx)
def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
if (
truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.model_config.max_model_len
):
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please request a smaller truncation size."
)
return None
def _create_pooling_params(
self,
ctx: ServeContext,
) -> PoolingParams | ErrorResponse:
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters"
)
return ctx.request.to_pooling_params()
async def _prepare_generators(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available")
for i, engine_input in enumerate(ctx.engine_inputs):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_input,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_input,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator)
ctx.result_generator = merge_async_iterators(*generators)
return None
async def _collect_batch(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect batch results from the result generator."""
if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_inputs)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
return self.create_error_response("Result generator not available")
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
return None
@staticmethod
def create_error_response(
message: str | Exception,
@@ -719,7 +544,7 @@ class OpenAIServing:
self,
request_id: str,
inputs: PromptType | EngineInput,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
params: SamplingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:

View File

@@ -112,7 +112,6 @@ class OpenAIServingModels:
self.model_config = self.engine_client.model_config
self.renderer = self.engine_client.renderer
self.io_processor = self.engine_client.io_processor
self.input_processor = self.engine_client.input_processor
async def init_static_loras(self):

View File

@@ -67,20 +67,18 @@ def init_pooling_state(
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.pooling.serving import ServingPooling
from vllm.entrypoints.pooling.scoring.serving import ServingScores
from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template)
state.serving_pooling = (
(
OpenAIServingPooling(
ServingPooling(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template=resolved_chat_template,

View File

@@ -4,8 +4,8 @@
from collections.abc import Sequence
from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType
from vllm.config import ModelConfig
from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
@@ -33,11 +33,12 @@ class PoolingIOProcessor:
def __init__(
self,
model_config: ModelConfig,
vllm_config: VllmConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
):
self.model_config = model_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.renderer = renderer
self.chat_template = chat_template_config.chat_template
@@ -48,12 +49,12 @@ class PoolingIOProcessor:
chat_template_config.trust_request_chat_template
)
def create_pooling_params(self, request):
return request.to_pooling_params()
#######################################
# online APIs
def create_pooling_params(self, request):
return request.to_pooling_params()
def pre_process_online(self, ctx: PoolingServeContext):
request = ctx.request
@@ -100,12 +101,16 @@ class PoolingIOProcessor:
# offline APIs
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert not isinstance(ctx.prompts, ScoringData)
assert not isinstance(ctx.prompts, ScoringData) and not (
isinstance(ctx.prompts, dict) and "data" in ctx.prompts
)
prompts_seq = prompt_to_seq(ctx.prompts)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return self._preprocess_completion_offline(
prompts=ctx.prompts, tok_params=tok_params
prompts=prompts_seq, tok_params=tok_params
)
async def pre_process_offline_async(self, ctx: OfflineInputsContext):
@@ -243,3 +248,19 @@ class PoolingIOProcessor:
"Refused request with untrusted chat template."
)
return None
def _params_to_seq(
self,
params: PoolingParams | Sequence[PoolingParams],
num_requests: int,
) -> Sequence[PoolingParams]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and params ({len(params)}) must be the same."
)
return params
return [params] * num_requests

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping
from http import HTTPStatus
from typing import ClassVar
@@ -9,7 +11,7 @@ from fastapi.responses import Response
from starlette.datastructures import Headers
from vllm import PoolingParams, PoolingRequestOutput, envs
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators
from .io_processor import PoolingIOProcessor
class PoolingServing:
class PoolingServingBase(ABC):
request_id_prefix: ClassVar[str]
def __init__(
@@ -50,10 +52,11 @@ class PoolingServing:
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.models = models
self.model_config = models.model_config
self.renderer = models.renderer
self.vllm_config = engine_client.vllm_config
self.max_model_len = self.model_config.max_model_len
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
@@ -63,31 +66,14 @@ class PoolingServing:
chat_template_content_format=chat_template_content_format,
trust_request_chat_template=trust_request_chat_template,
)
self.io_processor = self.init_io_processor(
model_config=models.model_config,
renderer=models.renderer,
chat_template_config=self.chat_template_config,
)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
raise NotImplementedError
@abstractmethod
async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
ctx = await self._init_ctx(request, raw_request)
await self.io_processor.pre_process_online_async(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
raise NotImplementedError
async def _init_ctx(
self,
@@ -124,9 +110,7 @@ class PoolingServing:
else await self._get_trace_headers(ctx.raw_request.headers)
)
if ctx.pooling_params is None:
pooling_params = self.io_processor.create_pooling_params(ctx.request)
else:
assert ctx.pooling_params is not None
pooling_params = ctx.pooling_params
if isinstance(pooling_params, list):
@@ -190,6 +174,7 @@ class PoolingServing:
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
@abstractmethod
async def _build_response(
self,
ctx: PoolingServeContext,
@@ -355,3 +340,39 @@ class PoolingServing:
params=params,
lora_request=lora_request,
)
class PoolingServing(PoolingServingBase, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.io_processor = self.init_io_processor(
vllm_config=self.vllm_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
@abstractmethod
def init_io_processor(
self,
vllm_config: VllmConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
raise NotImplementedError
async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
ctx = await self._init_ctx(request, raw_request)
await self.io_processor.pre_process_online_async(ctx)
if ctx.pooling_params is None:
ctx.pooling_params = self.io_processor.create_pooling_params(request)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
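With this split, a concrete serving class only supplies an IO processor factory and a response builder and inherits the request pipeline from PoolingServing shown above. A minimal sketch (the task name and response payload are illustrative, not part of this commit):

from fastapi.responses import JSONResponse

from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.typing import PoolingServeContext


class MyTaskIOProcessor(PoolingIOProcessor):
    # Hypothetical task name, mirroring ClassifyIOProcessor ("classify") and
    # EmbedIOProcessor ("embed").
    name = "my_task"


class ServingMyTask(PoolingServing):
    request_id_prefix = "my_task"

    def init_io_processor(self, *args, **kwargs) -> MyTaskIOProcessor:
        # Invoked from PoolingServing.__init__ with vllm_config, renderer and
        # chat_template_config, as in ServingClassification / ServingEmbedding.
        return MyTaskIOProcessor(*args, **kwargs)

    async def _build_response(self, ctx: PoolingServeContext) -> JSONResponse:
        # By the time this hook runs, the inherited __call__ has executed the
        # pre-process, generate, collect and post-process steps on ctx.
        return JSONResponse(content={"num_outputs": len(ctx.final_res_batch)})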

View File

@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
class ClassifyIOProcessor(PoolingIOProcessor):
name = "classification"
name = "classify"
class TokenClassifyIOProcessor(PoolingIOProcessor):
name = "token_classify"

View File

@@ -6,14 +6,11 @@ from typing import TypeAlias
import numpy as np
from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput
from vllm.renderers import BaseRenderer
from .io_processor import ClassifyIOProcessor
from .protocol import (
@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques
class ServingClassification(PoolingServing):
request_id_prefix = "classify"
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> ClassifyIOProcessor:
return ClassifyIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
return ClassifyIOProcessor(*args, **kwargs)
async def _build_response(
self,

View File

@@ -37,7 +37,7 @@ logger = init_logger(__name__)
class EmbedIOProcessor(PoolingIOProcessor):
name = "embedding"
name = "embed"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
request = ctx.request
if request.truncate == "NONE" and request.max_tokens is not None:
self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
class TokenEmbedIOProcessor(PoolingIOProcessor):
name = "token_embed"

View File

@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import (
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers import BaseRenderer
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__)
JSONResponseCLS = get_json_response_cls()
EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]
@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing):
request_id_prefix = "embd"
io_processor: EmbedIOProcessor
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> EmbedIOProcessor:
return EmbedIOProcessor(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.json_response_cls = get_json_response_cls()
def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
return EmbedIOProcessor(*args, **kwargs)
async def _build_response(
self,
@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing):
data=items,
usage=usage,
)
return JSONResponseCLS(content=response.model_dump())
return self.json_response_cls(content=response.model_dump())
def _openai_bytes_response(
self,
@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing):
media_type=response.media_type,
)
@staticmethod
def _build_cohere_response_from_ctx(
self,
ctx: PoolingServeContext,
) -> JSONResponse:
request = ctx.request
@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing):
),
),
)
return JSONResponse(content=response.model_dump(exclude_none=True))
return self.json_response_cls(content=response.model_dump(exclude_none=True))

View File

@@ -1,42 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.plugins.io_processors import has_io_processor
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask
from .base.io_processor import PoolingIOProcessor
from .utils import enable_scoring_api
def init_pooling_io_processors(
supported_tasks: tuple[SupportedTask, ...],
model_config: ModelConfig,
vllm_config: VllmConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
processors: list[tuple[str, type[PoolingIOProcessor]]] = []
model_config = vllm_config.model_config
processors: dict[str, type[PoolingIOProcessor]] = {}
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.io_processor import ClassifyIOProcessor
from .classify.io_processor import ClassifyIOProcessor
processors["classify"] = ClassifyIOProcessor
if "token_classify" in supported_tasks:
from .classify.io_processor import TokenClassifyIOProcessor
processors["token_classify"] = TokenClassifyIOProcessor
processors.append(("classify", ClassifyIOProcessor))
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from .embed.io_processor import EmbedIOProcessor
processors.append(("embed", EmbedIOProcessor))
processors["embed"] = EmbedIOProcessor
if "token_embed" in supported_tasks:
from .embed.io_processor import TokenEmbedIOProcessor
processors["token_embed"] = TokenEmbedIOProcessor
if has_io_processor(
vllm_config,
model_config.io_processor_plugin,
):
from .pooling.io_processor import PluginWithIOProcessorPlugins
processors["plugin"] = PluginWithIOProcessorPlugins
elif "plugin" in supported_tasks:
from .pooling.io_processor import PluginWithoutIOProcessorPlugins
processors["plugin"] = PluginWithoutIOProcessorPlugins
if enable_scoring_api(supported_tasks, model_config):
score_type = model_config.score_type
from .scoring.io_processor import ScoringIOProcessors
if score_type is not None and score_type in ScoringIOProcessors:
processors.append((score_type, ScoringIOProcessors[score_type]))
processors[score_type] = ScoringIOProcessors[score_type]
return {
task: processor_cls(
model_config=model_config,
vllm_config=vllm_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
for task, processor_cls in processors
for task, processor_cls in processors.items()
}
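A sketch of the factory's call site for orientation; the arguments are assumed to come from an already-initialized engine, mirroring LLM.__init__ and ServingPooling.__init__ in this commit.

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask


def build_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    vllm_config: VllmConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    # One instantiated PoolingIOProcessor per supported task (e.g. "embed",
    # "token_embed", "classify", ...), plus a "plugin" entry when the model
    # config declares an installed IOProcessor plugin.
    return init_pooling_io_processors(
        supported_tasks=supported_tasks,
        vllm_config=vllm_config,
        renderer=renderer,
        chat_template_config=chat_template_config,
    )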

View File

@@ -3,24 +3,17 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorResponse,
PoolingBytesResponse,
PoolingRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.pooling.serving import ServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation
router = APIRouter()
def pooling(request: Request) -> OpenAIServingPooling | None:
def pooling(request: Request) -> ServingPooling | None:
return request.app.state.serving_pooling
@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
if handler is None:
raise NotImplementedError("The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, PoolingBytesResponse):
return StreamingResponse(
content=generator.content,
headers=generator.headers,
media_type=generator.media_type,
)
assert_never(generator)
return await handler(request, raw_request)
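A hedged client-side sketch against the simplified route; the route path, port and model name are assumptions based on the standard vLLM server layout, not part of this diff.

import requests

# Assumes a vLLM server started with a pooling model on the default port.
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "my-pooling-model",      # placeholder served model name
        "input": "vLLM pooling example",
        "task": "token_embed",            # must be one of the model's supported tasks
    },
)
body = resp.json()
# Unsupported tasks surface as a BadRequestError whose message starts with
# "Unsupported task: ..." (or "No IOProcessor plugin installed." for "plugin").
print(body.get("error", body))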

View File

@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PoolingParams, PoolingRequestOutput
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext
from .protocol import IOProcessorRequest, IOProcessorResponse
logger = init_logger(__name__)
class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
name = "plugin"
class PluginWithIOProcessorPlugins(PoolingIOProcessor):
"""IO Processor plugins are a feature that allows pre- and post-processing
of the model input and output for pooling models."""
name = "plugin"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
io_processor = get_io_processor(
self.vllm_config,
self.renderer,
self.model_config.io_processor_plugin,
)
assert io_processor is not None
self.io_processor = io_processor
#######################################
# online APIs
def pre_process_online(self, ctx: PoolingServeContext):
assert isinstance(ctx.request, IOProcessorRequest)
validated_prompt = self.io_processor.parse_data(ctx.request.data)
raw_prompts = self.io_processor.pre_process(
prompt=validated_prompt, request_id=ctx.request_id
)
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(self.model_config, prompt)
)
for prompt in prompt_to_seq(raw_prompts)
]
tok_params = ctx.request.build_tok_params(self.model_config)
ctx.engine_inputs = self.renderer.render_cmpl(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(ctx.request, k, None)) is not None
},
)
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
ctx.pooling_params = pooling_params
def post_process_online(
self,
ctx: PoolingServeContext,
):
output = self.io_processor.post_process(
ctx.final_res_batch,
request_id=ctx.request_id,
)
if callable(
output_to_response := getattr(self.io_processor, "output_to_response", None)
):
logger.warning_once(
"`IOProcessor.output_to_response` is deprecated. To ensure "
"consistency between offline and online APIs, "
"`IOProcessorResponse` will become a transparent wrapper "
"around output data from v0.19 onwards.",
)
if hasattr(output, "request_id") and output.request_id is None:
output.request_id = ctx.request_id # type: ignore
ctx.response = output_to_response(output) # type: ignore
else:
ctx.response = IOProcessorResponse(request_id=ctx.request_id, data=output)
#######################################
# offline APIs
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, dict) and "data" in ctx.prompts
assert ctx.pooling_params is not None
# Validate the request data is valid for the loaded plugin
prompt_data = ctx.prompts.get("data")
if prompt_data is None:
raise ValueError(
"The 'data' field of the prompt is expected to contain "
"the prompt data and it cannot be None. "
"Refer to the documentation of the IOProcessor "
"in use for more details."
)
validated_prompt = self.io_processor.parse_data(prompt_data)
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
prompts_seq = prompt_to_seq(prompts)
params_seq: list[PoolingParams] = [
self.io_processor.merge_pooling_params(param)
for param in self._params_to_seq(
ctx.pooling_params,
len(prompts_seq),
)
]
for p in params_seq:
if p.task is None:
p.task = "plugin"
ctx.pooling_params = params_seq
ctx.prompts = prompts_seq
return super().pre_process_offline(ctx)
def post_process_offline(
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
processed_outputs = self.io_processor.post_process(ctx.outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(processed_outputs, "num_cached_tokens", 0),
prompt_token_ids=[],
finished=True,
)
]
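A minimal offline sketch of the plugin path above, assuming an IOProcessor plugin is installed; the model name, plugin name and payload are placeholders (see the prithvi_geospatial_mae_io_processor example referenced in the error message for a real setup).

from vllm import LLM

# Placeholders: a checkpoint whose config declares an io_processor_plugin.
llm = LLM(model="my-plugin-backed-model", io_processor_plugin="my_io_processor")

# The {"data": ...} prompt is validated by the plugin (parse_data), expanded
# into model prompts (pre_process), pooled, and finally wrapped back into a
# single PoolingRequestOutput by post_process_offline.
outputs = llm.encode(
    {"data": {"payload": 1}},
    pooling_task="plugin",
    use_tqdm=False,
)
print(outputs[0].outputs)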

View File

@@ -1,118 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import time
from collections.abc import AsyncGenerator, Callable, Sequence
from collections.abc import Callable
from functools import partial
from typing import Final, Literal, cast
from typing import Literal, cast
from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
PoolingBytesResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest,
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
get_json_response_cls,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.tasks import SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__)
class OpenAIServingPooling(OpenAIServing):
class ServingPooling(PoolingServingBase):
request_id_prefix = "pooling"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*args,
supported_tasks: tuple[SupportedTask, ...],
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
)
**kwargs,
):
super().__init__(*args, **kwargs)
self.supported_tasks = supported_tasks
self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
vllm_config=self.vllm_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
self.json_response_cls = get_json_response_cls()
async def create_pooling(
async def __call__(
self,
request: PoolingRequest,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
) -> Response:
assert isinstance(request, PoolingRequest)
pooling_task = self._verify_pooling_task(request)
model_name = self.models.model_name()
io_processor = self.io_processors[pooling_task]
ctx = await self._init_ctx(request, raw_request)
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
await io_processor.pre_process_online_async(ctx)
lora_request = self._maybe_get_adapters(request)
if ctx.pooling_params is None:
ctx.pooling_params = io_processor.create_pooling_params(request)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
def _verify_pooling_task(self, request: PoolingRequest) -> str:
if getattr(request, "dimensions", None) is not None:
raise ValueError("dimensions is currently not supported")
if request.task is None:
request.task = self.pooling_task
if getattr(request, "dimensions", None) is not None:
return self.create_error_response("dimensions is currently not supported")
if isinstance(request, IOProcessorRequest):
request.task = "plugin"
assert request.task is not None
pooling_task = request.task
# plugin task uses io_processor.parse_request to verify inputs
if request.task != "plugin" and request.task != self.pooling_task:
if request.task not in self.supported_tasks:
if pooling_task != "plugin" and pooling_task != self.pooling_task:
if pooling_task not in self.io_processors:
raise ValueError(
f"Unsupported task: {request.task!r} "
f"Unsupported task: {pooling_task!r} "
f"Supported tasks: {self.supported_tasks}"
)
else:
logger.warning_once(
"Pooling multitask support is deprecated and will be removed "
"in v0.20. When the default pooling task is not what you want, you "
'need to manually specify it via --pooler-config.task "%s". ',
request.task,
"need to manually specify it via --pooler-config.task %s. ",
pooling_task,
)
engine_inputs: Sequence[EngineInput]
if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None:
if pooling_task == "plugin" and "plugin" not in self.io_processors:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
@@ -120,133 +115,43 @@ class OpenAIServingPooling(OpenAIServing):
"offline inference example for more details."
)
validated_prompt = self.io_processor.parse_data(request.data)
return pooling_task
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
engine_inputs = await self.openai_serving_render.preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
async def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
if ctx.response is not None:
# for IOProcessorResponse
return self.json_response_cls(content=ctx.response.model_dump())
_, engine_inputs = await self.openai_serving_render.preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
engine_inputs = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError(f"Unsupported request of type {type(request)}")
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
if use_io_processor:
assert self.io_processor is not None
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
else:
pooling_params = request.to_pooling_params() # type: ignore
for i, engine_input in enumerate(engine_inputs):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_input,
params=pooling_params,
lora_request=lora_request,
if encoding_format == "float" or encoding_format == "base64":
return self.request_output_to_pooling_json_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_pooling_bytes_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
generator = self.engine_client.encode(
engine_input,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
result_generator = merge_async_iterators(*generators)
if use_io_processor:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
result_generator,
request_id=request_id,
)
if callable(
output_to_response := getattr(
self.io_processor, "output_to_response", None
)
):
logger.warning_once(
"`IOProcessor.output_to_response` is deprecated. To ensure "
"consistency between offline and online APIs, "
"`IOProcessorResponse` will become a transparent wrapper "
"around output data from v0.19 onwards.",
)
if hasattr(output, "request_id") and output.request_id is None:
output.request_id = request_id # type: ignore
return output_to_response(output) # type: ignore
return IOProcessorResponse(request_id=request_id, data=output)
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_inputs)
# Non-streaming response
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)
response = self.request_output_to_pooling_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
request.encoding_format,
request.embed_dtype,
request.endianness,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
return response
assert_never(encoding_format)
def request_output_to_pooling_json_response(
self,
@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format: Literal["float", "base64"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse:
) -> JSONResponse:
encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str],
(
@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing):
total_tokens=num_prompt_tokens,
)
return PoolingResponse(
response = PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
return self.json_response_cls(content=response.model_dump())
def request_output_to_pooling_bytes_response(
self,
@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format: Literal["bytes", "bytes_only"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingBytesResponse:
) -> StreamingResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch,
embed_dtype=embed_dtype,
@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing):
}
)
return PoolingBytesResponse(content=content, headers=headers)
response = PoolingBytesResponse(content=content, headers=headers)
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
if encoding_format == "float" or encoding_format == "base64":
return self.request_output_to_pooling_json_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
return StreamingResponse(
content=response.content,
headers=response.headers,
media_type=response.media_type,
)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_pooling_bytes_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
assert_never(encoding_format)
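
For orientation, the new _build_response path reduces to a dispatch on encoding_format: "float" and "base64" produce a JSON body, "bytes" and "bytes_only" stream raw tensor bytes, and assert_never keeps the branch exhaustive. Below is a minimal, self-contained sketch of that pattern; the EncodingFormat alias, the dispatch_response name, and the string return values are illustrative stand-ins rather than the vLLM response classes, and Python 3.11+ is assumed for typing.assert_never.

# Minimal sketch of the encoding_format dispatch used by _build_response.
# EncodingFormat and the return values are illustrative stand-ins only.
from typing import Literal, assert_never

EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]


def dispatch_response(encoding_format: EncodingFormat) -> str:
    # OpenAI-style formats are serialized into a JSON body.
    if encoding_format == "float" or encoding_format == "base64":
        return "json"
    # Raw tensor bytes are streamed instead of JSON-encoded.
    if encoding_format == "bytes" or encoding_format == "bytes_only":
        return "streaming"
    # Exhaustiveness check: an unhandled literal fails type checking here.
    assert_never(encoding_format)


assert dispatch_response("base64") == "json"
assert dispatch_response("bytes_only") == "streaming"
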

View File

@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, ScoringData)
assert not isinstance(ctx.pooling_params, list)
assert not isinstance(ctx.pooling_params, Sequence)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})

View File

@@ -4,15 +4,13 @@
from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params,
build_late_interaction_query_params,
@@ -52,22 +50,17 @@ class ServingScores(PoolingServing):
super().__init__(engine_client, *args, **kwargs)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
self, vllm_config: VllmConfig, *args, **kwargs
) -> PoolingIOProcessor:
model_config = vllm_config.model_config
score_type: str = model_config.score_type
if self.enable_flash_late_interaction:
score_type = "flash-late-interaction"
assert score_type in ScoringIOProcessors
processor_cls = ScoringIOProcessors[score_type]
return processor_cls(
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
return processor_cls(vllm_config, *args, **kwargs)
async def __call__(self, *args, **kwargs) -> Response:
if not self.enable_flash_late_interaction:
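
The init_io_processor hook now receives the whole VllmConfig and forwards *args/**kwargs to the selected processor class, instead of unpacking model_config, renderer and chat_template_config by hand. A simplified stand-in of that factory pattern follows; Config, BaseProcessor, ScoreProcessor and PROCESSORS are hypothetical names, not vLLM's.

# Sketch of the refactored factory: pass the whole config through and let the
# processor pick what it needs, so subclasses stop unpacking individual fields.
from dataclasses import dataclass


@dataclass
class Config:
    score_type: str = "cross-encoder"


class BaseProcessor:
    def __init__(self, config: Config, *args, **kwargs) -> None:
        self.config = config


class ScoreProcessor(BaseProcessor):
    pass


PROCESSORS = {"cross-encoder": ScoreProcessor}


def init_io_processor(config: Config, *args, **kwargs) -> BaseProcessor:
    processor_cls = PROCESSORS[config.score_type]
    return processor_cls(config, *args, **kwargs)


assert isinstance(init_io_processor(Config()), ScoreProcessor)
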

View File

@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import (
)
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.inputs import EngineInput
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest
PoolingCompletionLikeRequest: TypeAlias = (
@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
## for bi-encoder & late-interaction
n_queries: int | None = None
## for IOProcessorResponse
response: Any | None = None
@dataclass
class OfflineInputsContext:
prompts: PromptType | Sequence[PromptType] | ScoringData
pooling_params: PoolingParams | list[PoolingParams] | None = None
prompts: PromptType | Sequence[PromptType] | DataPrompt | ScoringData
pooling_params: PoolingParams | Sequence[PoolingParams]
tokenization_kwargs: dict[str, Any] | None = None
chat_template: str | None = None
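
OfflineInputsContext drops the None default on pooling_params, so offline call sites must now pass params explicitly, either a single object or a per-prompt sequence. A stand-in sketch of the effect; Params and OfflineCtx are illustrative, not the vLLM dataclasses.

# pooling_params is now a required field: omitting it fails at construction.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass
class Params:
    task: str = "embed"


@dataclass
class OfflineCtx:
    prompts: Any
    pooling_params: Params | Sequence[Params]  # required, no default
    tokenization_kwargs: dict[str, Any] | None = None
    chat_template: str | None = None


ctx = OfflineCtx(prompts=["hello"], pooling_params=Params())
try:
    OfflineCtx(prompts=["hello"])  # type: ignore[call-arg]
except TypeError:
    pass  # missing pooling_params is now a construction-time error
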

View File

@@ -14,7 +14,7 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

View File

@@ -65,7 +65,6 @@ class OpenAIServingRender:
self,
model_config: ModelConfig,
renderer: BaseRenderer,
io_processor: Any,
model_registry: OpenAIModelRegistry,
*,
request_logger: RequestLogger | None,
@@ -81,7 +80,6 @@ class OpenAIServingRender:
) -> None:
self.model_config = model_config
self.renderer = renderer
self.io_processor = io_processor
self.model_registry = model_registry
self.request_logger = request_logger
self.chat_template = chat_template

View File

@@ -12,6 +12,23 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def has_io_processor(
vllm_config: VllmConfig,
plugin_from_init: str | None = None,
):
if plugin_from_init:
model_plugin = plugin_from_init
else:
# A plugin can be specified via the model config
# Retrieve the model specific plugin if available
# This is using a custom field in the hf_config for the model
hf_config = vllm_config.model_config.hf_config.to_dict()
config_plugin = hf_config.get("io_processor_plugin")
model_plugin = config_plugin
return model_plugin is not None
def get_io_processor(
vllm_config: VllmConfig,
renderer: BaseRenderer,
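
A small usage sketch for has_io_processor above, run against a duck-typed config (SimpleNamespace) instead of a real VllmConfig; the helper body is copied from this diff, and only the fake-config scaffolding is new.

# has_io_processor answers: was a plugin name given explicitly, or does the
# model's hf_config carry an io_processor_plugin field?
from types import SimpleNamespace


def has_io_processor(vllm_config, plugin_from_init: str | None = None):
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # Fall back to the custom io_processor_plugin field in hf_config.
        hf_config = vllm_config.model_config.hf_config.to_dict()
        model_plugin = hf_config.get("io_processor_plugin")
    return model_plugin is not None


class _HFConfig:
    def __init__(self, plugin: str | None) -> None:
        self._plugin = plugin

    def to_dict(self) -> dict:
        return {"io_processor_plugin": self._plugin}


def _config(plugin: str | None) -> SimpleNamespace:
    return SimpleNamespace(model_config=SimpleNamespace(hf_config=_HFConfig(plugin)))


assert has_io_processor(_config("my_io_plugin")) is True
assert has_io_processor(_config(None)) is False
assert has_io_processor(_config(None), plugin_from_init="my_io_plugin") is True
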

View File

@@ -26,7 +26,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs.preprocess import extract_prompt_components
@@ -133,11 +132,6 @@ class AsyncLLM(EngineClient):
)
self.renderer = renderer = renderer_from_config(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.renderer,
self.model_config.io_processor_plugin,
)
# Convert EngineInput --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)

View File

@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs.preprocess import extract_prompt_components
@@ -90,11 +89,6 @@ class LLMEngine:
self.should_execute_dummy_batch = False
self.renderer = renderer = renderer_from_config(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.renderer,
self.model_config.io_processor_plugin,
)
# Convert EngineInput --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)