diff --git a/tests/entrypoints/openai/chat_completion/test_chat_error.py b/tests/entrypoints/openai/chat_completion/test_chat_error.py index 46070e481..f1fb7c751 100644 --- a/tests/entrypoints/openai/chat_completion/test_chat_error.py +++ b/tests/entrypoints/openai/chat_completion/test_chat_error.py @@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: serving_render = OpenAIServingRender( model_config=engine.model_config, renderer=engine.renderer, - io_processor=engine.io_processor, model_registry=models.registry, request_logger=None, chat_template=None, @@ -123,7 +122,6 @@ async def test_chat_error_non_stream(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -173,7 +171,6 @@ async def test_chat_error_stream(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index cb356e0e1..39d59d28f 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -567,7 +567,6 @@ def _build_serving_render( return OpenAIServingRender( model_config=engine.model_config, renderer=engine.renderer, - io_processor=engine.io_processor, model_registry=model_registry, request_logger=None, chat_template=CHAT_TEMPLATE, @@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: class MockEngine: model_config: MockModelConfig = field(default_factory=MockModelConfig) input_processor: MagicMock = field(default_factory=MagicMock) - io_processor: MagicMock = field(default_factory=MagicMock) renderer: MagicMock = field(default_factory=MagicMock) @@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat @@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -779,7 +773,6 @@ 
async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat @@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): mock_engine.errored = False mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True) mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_renderer = MistralRenderer( @@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): mock_engine.errored = False mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True) mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_renderer = MistralRenderer( @@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config(): mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) # Initialize the serving chat @@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_engine.errored = False mock_engine.model_config = mock_model_config mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_chat = _build_serving_chat(mock_engine) @@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) # Mock the generate method to return an async generator @@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony: mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) return mock_engine @@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) models = OpenAIServingModels( @@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) models = OpenAIServingModels( diff --git a/tests/entrypoints/openai/completion/test_completion_error.py b/tests/entrypoints/openai/completion/test_completion_error.py index 46eb02e3c..3349f4126 100644 --- a/tests/entrypoints/openai/completion/test_completion_error.py +++ b/tests/entrypoints/openai/completion/test_completion_error.py @@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: serving_render = OpenAIServingRender( model_config=engine.model_config, renderer=engine.renderer, - 
io_processor=engine.io_processor, model_registry=models.registry, request_logger=None, chat_template=None, @@ -107,7 +106,6 @@ async def test_completion_error_non_stream(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_completion = _build_serving_completion(mock_engine) @@ -157,7 +155,6 @@ async def test_completion_error_stream(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) serving_completion = _build_serving_completion(mock_engine) diff --git a/tests/entrypoints/openai/completion/test_lora_resolvers.py b/tests/entrypoints/openai/completion/test_lora_resolvers.py index 8d5283de5..6a0bec925 100644 --- a/tests/entrypoints/openai/completion/test_lora_resolvers.py +++ b/tests/entrypoints/openai/completion/test_lora_resolvers.py @@ -137,7 +137,6 @@ def mock_serving_setup(): mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() mock_engine.renderer = _build_renderer(mock_engine.model_config) models = OpenAIServingModels( @@ -148,7 +147,6 @@ def mock_serving_setup(): serving_render = OpenAIServingRender( model_config=mock_engine.model_config, renderer=mock_engine.renderer, - io_processor=mock_engine.io_processor, model_registry=models.registry, request_logger=None, chat_template=None, diff --git a/tests/entrypoints/openai/generative_scoring/test_generative_scoring.py b/tests/entrypoints/openai/generative_scoring/test_generative_scoring.py index a260027af..632c4bcc9 100644 --- a/tests/entrypoints/openai/generative_scoring/test_generative_scoring.py +++ b/tests/entrypoints/openai/generative_scoring/test_generative_scoring.py @@ -77,7 +77,6 @@ def _create_mock_engine(): mock_engine.errored = False mock_engine.model_config = MockModelConfig() mock_engine.input_processor = MagicMock() - mock_engine.io_processor = MagicMock() # renderer is accessed by OpenAIServing.__init__ and serving.py mock_renderer = MagicMock() diff --git a/tests/entrypoints/openai/responses/test_serving_responses.py b/tests/entrypoints/openai/responses/test_serving_responses.py index 9e2b9a7fc..32f4bf34b 100644 --- a/tests/entrypoints/openai/responses/test_serving_responses.py +++ b/tests/entrypoints/openai/responses/test_serving_responses.py @@ -218,7 +218,6 @@ class TestInitializeToolSessions: engine_client.model_config = model_config engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() engine_client.renderer = MagicMock() models = MagicMock() @@ -307,7 +306,6 @@ class TestValidateGeneratorInput: engine_client.model_config = model_config engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() engine_client.renderer = MagicMock() models = MagicMock() @@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch): model_config.get_diff_sampling_param.return_value = {} engine_client.model_config = model_config engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() engine_client.renderer = MagicMock() tokenizer = FakeTokenizer() @@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning(): model_config.get_diff_sampling_param.return_value = {} engine_client.model_config = 
model_config engine_client.input_processor = MagicMock() - engine_client.io_processor = MagicMock() engine_client.renderer = MagicMock() models = MagicMock() diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py index f556dd579..828443ebc 100644 --- a/tests/entrypoints/pooling/classify/test_offline.py +++ b/tests/entrypoints/pooling/classify/test_offline.py @@ -110,8 +110,11 @@ def test_score_api(llm: LLM): llm.score("ping", "pong", use_tqdm=False) -@pytest.mark.parametrize("task", ["embed", "token_embed"]) +@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) def test_unsupported_tasks(llm: LLM, task: PoolingTask): - err_msg = "Embedding API is not supported by this model.+" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = "Embedding API is not supported by this model.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompt, pooling_task=task, use_tqdm=False) diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index ed295a09b..848ce4083 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ b/tests/entrypoints/pooling/classify/test_online.py @@ -469,4 +469,8 @@ async def test_pooling_not_supported( }, ) assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = f"Unsupported task: {task!r}" + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py index e8d84ed45..1ffeb027b 100644 --- a/tests/entrypoints/pooling/embed/test_offline.py +++ b/tests/entrypoints/pooling/embed/test_offline.py @@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM): ) -@pytest.mark.parametrize("task", ["token_classify", "classify"]) +@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"]) def test_unsupported_tasks(llm: LLM, task: PoolingTask): - err_msg = "Classification API is not supported by this model.+" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = "Classification API is not supported by this model.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompt, pooling_task=task, use_tqdm=False) diff --git a/tests/entrypoints/pooling/embed/test_online.py b/tests/entrypoints/pooling/embed/test_online.py index dc61244c9..3032645c7 100644 --- a/tests/entrypoints/pooling/embed/test_online.py +++ b/tests/entrypoints/pooling/embed/test_online.py @@ -767,4 +767,8 @@ async def test_pooling_not_supported( }, ) assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") + if task == "plugin": + err_msg = "No IOProcessor plugin installed." 
+ else: + err_msg = f"Unsupported task: {task!r}" + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/pooling/scoring/test_bi_encoder_online.py b/tests/entrypoints/pooling/scoring/test_bi_encoder_online.py index 38146084e..392514056 100644 --- a/tests/entrypoints/pooling/scoring/test_bi_encoder_online.py +++ b/tests/entrypoints/pooling/scoring/test_bi_encoder_online.py @@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = f"Unsupported task: {task!r}" + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/pooling/scoring/test_cross_encoder_online.py b/tests/entrypoints/pooling/scoring/test_cross_encoder_online.py index ebb339263..059a32d7c 100644 --- a/tests/entrypoints/pooling/scoring/test_cross_encoder_online.py +++ b/tests/entrypoints/pooling/scoring/test_cross_encoder_online.py @@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str): }, ) assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}") + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = f"Unsupported task: {task!r}" + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/pooling/token_classify/test_offline.py b/tests/entrypoints/pooling/token_classify/test_offline.py index f7a746754..d36761466 100644 --- a/tests/entrypoints/pooling/token_classify/test_offline.py +++ b/tests/entrypoints/pooling/token_classify/test_offline.py @@ -65,14 +65,17 @@ def test_score_api(llm: LLM): llm.score("ping", "pong", use_tqdm=False) -@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"]) +@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"]) def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm): if task == "classify": with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"): llm.encode(prompt, pooling_task=task, use_tqdm=False) assert "deprecated" in caplog_vllm.text else: - err_msg = "Embedding API is not supported by this model.+" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." 
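For context, the offline expectation these parametrizations encode can be exercised in isolation. A minimal sketch, assuming `llm` is an already constructed pooling `LLM` with no IO processor plugin configured (the helper name is illustrative):

import pytest

from vllm import LLM


def assert_plugin_task_rejected(llm: LLM, prompt: str = "ping") -> None:
    # Without an installed IOProcessor plugin, the "plugin" pooling task is
    # rejected before any request is scheduled.
    with pytest.raises(ValueError, match="No IOProcessor plugin installed."):
        llm.encode(prompt, pooling_task="plugin", use_tqdm=False)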
+ else: + err_msg = "Embedding API is not supported by this model.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompt, pooling_task=task, use_tqdm=False) diff --git a/tests/entrypoints/pooling/token_classify/test_online.py b/tests/entrypoints/pooling/token_classify/test_online.py index e91d0bc9a..39fd37833 100644 --- a/tests/entrypoints/pooling/token_classify/test_online.py +++ b/tests/entrypoints/pooling/token_classify/test_online.py @@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"]) +@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) async def test_pooling_not_supported( server: RemoteOpenAIServer, model_name: str, task: str ): @@ -64,7 +64,8 @@ async def test_pooling_not_supported( }, ) - if task != "classify": - assert response.json()["error"]["type"] == "BadRequestError" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: err_msg = f"Unsupported task: {task!r}" - assert response.json()["error"]["message"].startswith(err_msg) + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/pooling/token_embed/test_offline.py b/tests/entrypoints/pooling/token_embed/test_offline.py index 697f4f81a..d2e87fbf2 100644 --- a/tests/entrypoints/pooling/token_embed/test_offline.py +++ b/tests/entrypoints/pooling/token_embed/test_offline.py @@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM): assert outputs[0].outputs.data.shape == (11, 384) -@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"]) +@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"]) def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm): if task == "embed": with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"): llm.encode(prompt, pooling_task=task, use_tqdm=False) assert "deprecated" in caplog_vllm.text else: - err_msg = "Classification API is not supported by this model.+" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." + else: + err_msg = "Classification API is not supported by this model.+" with pytest.raises(ValueError, match=err_msg): llm.encode(prompt, pooling_task=task, use_tqdm=False) diff --git a/tests/entrypoints/pooling/token_embed/test_online.py b/tests/entrypoints/pooling/token_embed/test_online.py index 922c624e9..048491dac 100644 --- a/tests/entrypoints/pooling/token_embed/test_online.py +++ b/tests/entrypoints/pooling/token_embed/test_online.py @@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"]) +@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"]) async def test_pooling_not_supported( server: RemoteOpenAIServer, model_name: str, task: str ): @@ -87,7 +87,8 @@ async def test_pooling_not_supported( }, ) - if task != "embed": - assert response.json()["error"]["type"] == "BadRequestError" + if task == "plugin": + err_msg = "No IOProcessor plugin installed." 
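The online counterpart of the same expectation; the endpoint path and payload keys below are assumptions inferred from the surrounding tests, not a documented API:

import requests


def assert_plugin_task_rejected_online(base_url: str, model: str) -> None:
    # Assumed endpoint/payload shape: POST <base_url>/pooling with a "task"
    # field, mirroring what the parametrized online tests send.
    resp = requests.post(
        f"{base_url}/pooling",
        json={"model": model, "input": "ping", "task": "plugin"},
    )
    err = resp.json()["error"]
    assert err["type"] == "BadRequestError"
    assert err["message"].startswith("No IOProcessor plugin installed.")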
+ else: err_msg = f"Unsupported task: {task!r}" - assert response.json()["error"]["message"].startswith(err_msg) + assert response.json()["error"]["message"].startswith(err_msg) diff --git a/tests/entrypoints/serve/disagg/test_generate_stream.py b/tests/entrypoints/serve/disagg/test_generate_stream.py index a9ca02630..76a9df22f 100644 --- a/tests/entrypoints/serve/disagg/test_generate_stream.py +++ b/tests/entrypoints/serve/disagg/test_generate_stream.py @@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens: serving_render = OpenAIServingRender( model_config=engine.model_config, renderer=engine.renderer, - io_processor=engine.io_processor, model_registry=models.registry, request_logger=None, chat_template=None, @@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock: engine.errored = False engine.model_config = MockModelConfig() engine.input_processor = MagicMock() - engine.io_processor = MagicMock() engine.renderer = _build_renderer(engine.model_config) return engine diff --git a/tests/entrypoints/serve/lora/test_serving_models.py b/tests/entrypoints/serve/lora/test_serving_models.py index f6755f489..ce9fdcc2b 100644 --- a/tests/entrypoints/serve/lora/test_serving_models.py +++ b/tests/entrypoints/serve/lora/test_serving_models.py @@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels: mock_model_config.max_model_len = 2048 mock_engine_client.model_config = mock_model_config mock_engine_client.input_processor = MagicMock() - mock_engine_client.io_processor = MagicMock() mock_engine_client.renderer = MagicMock() serving_models = OpenAIServingModels( diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 69a1c38a4..21a651c62 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -514,7 +514,6 @@ async def test_header_dp_rank_argument(): serving_render = OpenAIServingRender( model_config=engine.model_config, renderer=engine.renderer, - io_processor=engine.io_processor, model_registry=models.registry, request_logger=None, chat_template=None, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 3d466e3fc..50013a060 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,7 +14,6 @@ from vllm.distributed.weight_transfer.base import ( from vllm.inputs import EngineInput, PromptType from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer from vllm.sampling_params import SamplingParams @@ -44,7 +43,6 @@ class EngineClient(ABC): vllm_config: VllmConfig model_config: ModelConfig renderer: BaseRenderer - io_processor: IOProcessor | None input_processor: InputProcessor @property diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1be2cdd5c..d296e84d0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import ( load_chat_template, ) from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors -from vllm.entrypoints.pooling.scoring.io_processor import ( - ScoringIOProcessor, -) +from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessor from vllm.entrypoints.pooling.scoring.typing import ScoreInput from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext from vllm.entrypoints.utils import log_non_default_args @@ 
-398,12 +396,11 @@ class LLM: self.runner_type = self.model_config.runner_type self.renderer = self.llm_engine.renderer self.chat_template = load_chat_template(chat_template) - self.io_processor = self.llm_engine.io_processor self.input_processor = self.llm_engine.input_processor self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) self.pooling_io_processors = init_pooling_io_processors( supported_tasks=supported_tasks, - model_config=self.model_config, + vllm_config=self.llm_engine.vllm_config, renderer=self.renderer, chat_template_config=self.chat_template_config, ) @@ -1081,118 +1078,55 @@ class LLM: pooled hidden states in the same order as the input prompts. """ - self._verify_pooling_task(pooling_task) - - if isinstance(prompts, dict) and "data" in prompts: - if self.io_processor is None: - raise ValueError( - "No IOProcessor plugin installed. Please refer " - "to the documentation and to the " - "'prithvi_geospatial_mae_io_processor' " - "offline inference example for more details." - ) - - # Validate the request data is valid for the loaded plugin - prompt_data = prompts.get("data") - if prompt_data is None: - raise ValueError( - "The 'data' field of the prompt is expected to contain " - "the prompt data and it cannot be None. " - "Refer to the documentation of the IOProcessor " - "in use for more details." - ) - validated_prompt = self.io_processor.parse_data(prompt_data) - - # obtain the actual model prompts from the pre-processor - prompts = self.io_processor.pre_process(prompt=validated_prompt) - prompts_seq = prompt_to_seq(prompts) - - params_seq: Sequence[PoolingParams] = [ - self.io_processor.merge_pooling_params(param) - for param in self._params_to_seq( - pooling_params, - len(prompts_seq), - ) - ] - for p in params_seq: - if p.task is None: - p.task = "plugin" - - outputs = self._run_completion( - prompts=prompts_seq, - params=params_seq, - output_type=PoolingRequestOutput, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, + if isinstance(prompts, dict) and "data" in prompts and pooling_task != "plugin": + raise ValueError( + "The 'data' field is only supported for the 'plugin' pooling task." ) + self._verify_pooling_task(pooling_task) + assert pooling_task is not None and pooling_task in self.pooling_io_processors - # get the post-processed model outputs - assert self.io_processor is not None - processed_outputs = self.io_processor.post_process(outputs) + io_processor = self.pooling_io_processors[pooling_task] - return [ - PoolingRequestOutput[Any]( - request_id="", - outputs=processed_outputs, - num_cached_tokens=getattr( - processed_outputs, "num_cached_tokens", 0 - ), - prompt_token_ids=[], - finished=True, - ) - ] - else: - if pooling_params is None: - # Use default pooling params. - pooling_params = PoolingParams() + if pooling_params is None: + pooling_params = PoolingParams() - prompts_seq = prompt_to_seq(prompts) - params_seq = self._params_to_seq(pooling_params, len(prompts_seq)) + ctx = OfflineInputsContext( + prompts=prompts, + pooling_params=pooling_params, + tokenization_kwargs=tokenization_kwargs, + ) - for param in params_seq: - if param.task is None: - param.task = pooling_task - elif param.task != pooling_task: - msg = ( - f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" 
- ) - raise ValueError(msg) + engine_inputs = io_processor.pre_process_offline(ctx) + n_inputs = len(engine_inputs) + assert ctx.pooling_params is not None - if pooling_task in self.pooling_io_processors: - io_processor = self.pooling_io_processors[pooling_task] - processor_inputs = io_processor.pre_process_offline( - ctx=OfflineInputsContext( - prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs - ) - ) - seq_lora_requests = self._lora_request_to_seq( - lora_request, len(prompts_seq) - ) - seq_priority = self._priority_to_seq(None, len(prompts)) + params_seq = self._params_to_seq(ctx.pooling_params, n_inputs) - self._render_and_add_requests( - prompts=processor_inputs, - params=params_seq, - lora_requests=seq_lora_requests, - priorities=seq_priority, - ) + for param in params_seq: + if param.task is None: + param.task = pooling_task + elif pooling_task == "plugin": + # `plugin` task uses io_processor.parse_request to verify inputs. + # We actually allow plugin to overwrite pooling_task. + pass + elif param.task != pooling_task: + msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" + raise ValueError(msg) - outputs = self._run_engine( - use_tqdm=use_tqdm, output_type=PoolingRequestOutput - ) - outputs = io_processor.post_process_offline( - ctx=OfflineOutputsContext(outputs=outputs) - ) - else: - outputs = self._run_completion( - prompts=prompts_seq, - params=params_seq, - output_type=PoolingRequestOutput, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) + seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs) + seq_priority = self._priority_to_seq(None, n_inputs) + + self._render_and_add_requests( + prompts=engine_inputs, + params=params_seq, + lora_requests=seq_lora_requests, + priorities=seq_priority, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput) + outputs = io_processor.post_process_offline( + ctx=OfflineOutputsContext(outputs=outputs) + ) return outputs def _verify_pooling_task(self, pooling_task: PoolingTask | None): @@ -1254,6 +1188,14 @@ class LLM: pooling_task, ) + if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details." + ) + def embed( self, prompts: PromptType | Sequence[PromptType], @@ -1458,6 +1400,9 @@ class LLM: scoring_data = io_processor.valid_inputs(data_1, data_2) n_queries = len(scoring_data.data_1) + if pooling_params is None: + pooling_params = PoolingParams() + ctx = OfflineInputsContext( prompts=scoring_data, pooling_params=pooling_params, @@ -1466,15 +1411,11 @@ class LLM: n_queries=n_queries, ) - processor_inputs = io_processor.pre_process_offline(ctx) + engine_inputs = io_processor.pre_process_offline(ctx) + n_inputs = len(engine_inputs) - seq_lora_requests = self._lora_request_to_seq( - lora_request, len(processor_inputs) - ) - - if ctx.pooling_params is None: - ctx.pooling_params = PoolingParams() - params_seq = self._params_to_seq(ctx.pooling_params, len(processor_inputs)) + seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs) + params_seq = self._params_to_seq(ctx.pooling_params, n_inputs) for param in params_seq: if param.task is None: @@ -1483,10 +1424,10 @@ class LLM: msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" 
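Taken together, the reworked `encode()` above routes every pooling task through a task-keyed `PoolingIOProcessor`: pre-processing builds the engine inputs, `params.task` is filled in from `pooling_task` (conflicting values raise, except for the "plugin" task), and post-processing runs on the collected outputs. A caller-side sketch, with the prompts and helper name purely illustrative:

from vllm import LLM, PoolingParams


def embed_texts(llm: LLM, texts: list[str]):
    # pooling_task selects the IO processor from llm.pooling_io_processors;
    # leaving PoolingParams.task unset lets encode() fill it in.
    outputs = llm.encode(
        texts,
        pooling_task="embed",
        pooling_params=PoolingParams(),
        use_tqdm=False,
    )
    return [out.outputs.data for out in outputs]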
raise ValueError(msg) - seq_priority = self._priority_to_seq(None, len(processor_inputs)) + seq_priority = self._priority_to_seq(None, n_inputs) self._render_and_add_requests( - prompts=processor_inputs, + prompts=engine_inputs, params=params_seq, lora_requests=seq_lora_requests, priorities=seq_priority, @@ -1579,7 +1520,7 @@ class LLM: if isinstance(params, Sequence): if len(params) != num_requests: raise ValueError( - f"The lengths of prompts ({params}) " + f"The lengths of prompts ({num_requests}) " f"and params ({len(params)}) must be the same." ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2b6cb810e..85d2fe43d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -370,7 +370,6 @@ async def init_app_state( state.openai_serving_render = OpenAIServingRender( model_config=engine_client.model_config, renderer=engine_client.renderer, - io_processor=engine_client.io_processor, model_registry=state.openai_serving_models.registry, request_logger=request_logger, chat_template=resolved_chat_template, @@ -441,13 +440,12 @@ async def init_render_app_state( Unlike :func:`init_app_state` this function does not require an :class:`~vllm.engine.protocol.EngineClient`; it bootstraps the - preprocessing pipeline (renderer, io_processor, input_processor) + preprocessing pipeline (renderer, input_processor) directly from the :class:`~vllm.config.VllmConfig`. """ from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry from vllm.entrypoints.serve.render.serving import OpenAIServingRender - from vllm.plugins.io_processors import get_io_processor from vllm.renderers import renderer_from_config served_model_names = args.served_model_name or [args.model] @@ -465,15 +463,11 @@ async def init_render_app_state( request_logger = None renderer = renderer_from_config(vllm_config) - io_processor = get_io_processor( - vllm_config, renderer, vllm_config.model_config.io_processor_plugin - ) resolved_chat_template = load_chat_template(args.chat_template) state.openai_serving_render = OpenAIServingRender( model_config=vllm_config.model_config, renderer=renderer, - io_processor=io_processor, model_registry=model_registry, request_logger=request_logger, chat_template=resolved_chat_template, diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f5f011a96..5bd415b4f 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( TranscriptionResponse, TranslationRequest, ) -from vllm.entrypoints.pooling.pooling.protocol import ( - IOProcessorRequest, - PoolingChatRequest, - PoolingCompletionRequest, - PoolingResponse, -) from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, @@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams +from vllm.outputs import CompletionOutput, RequestOutput from vllm.renderers import ChatParams, TokenizeParams from vllm.renderers.inputs.preprocess import ( extract_prompt_components, @@ -78,10 +71,7 @@ 
from vllm.tracing import ( log_tracing_disabled_warning, ) from vllm.utils import random_uuid -from vllm.utils.async_utils import ( - collect_from_async_generator, - merge_async_iterators, -) +from vllm.utils.async_utils import collect_from_async_generator logger = init_logger(__name__) @@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol): CompletionLikeRequest: TypeAlias = ( - CompletionRequest - | TokenizeCompletionRequest - | DetokenizeRequest - | PoolingCompletionRequest + CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest ) ChatLikeRequest: TypeAlias = ( - ChatCompletionRequest - | BatchChatCompletionRequest - | TokenizeChatRequest - | PoolingChatRequest + ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest ) SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest @@ -121,7 +105,6 @@ AnyRequest: TypeAlias = ( | ChatLikeRequest | SpeechToTextRequest | ResponsesRequest - | IOProcessorRequest | GenerateRequest ) @@ -130,7 +113,6 @@ AnyResponse: TypeAlias = ( | ChatCompletionResponse | TranscriptionResponse | TokenizeResponse - | PoolingResponse | GenerateResponse ) @@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]): created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None engine_inputs: list[EngineInput] | None = None - - result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( - None - ) - final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) - model_config = ConfigDict(arbitrary_types_allowed=True) @@ -171,7 +147,6 @@ class OpenAIServing: super().__init__() self.engine_client = engine_client - self.models = models self.request_logger = request_logger @@ -179,7 +154,6 @@ class OpenAIServing: self.model_config = engine_client.model_config self.renderer = engine_client.renderer - self.io_processor = engine_client.io_processor self.input_processor = engine_client.input_processor async def beam_search( @@ -381,155 +355,6 @@ class OpenAIServing: prompt_logprobs=None, ) - async def _preprocess( - self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """ - Default preprocessing hook. Subclasses may override to prepare `ctx`. - """ - return None - - def _build_response( - self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - """ - Default response builder. Subclass may override this method - to return the appropriate response object. 
- """ - return self.create_error_response("unimplemented endpoint") - - async def handle( - self, - ctx: ServeContext, - ) -> AnyResponse | ErrorResponse: - async for response in self._pipeline(ctx): - return response - - return self.create_error_response("No response yielded from pipeline") - - async def _pipeline( - self, - ctx: ServeContext, - ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: - """Execute the request processing pipeline yielding responses.""" - if error := await self._check_model(ctx.request): - yield error - if error := self._validate_request(ctx): - yield error - - preprocess_ret = await self._preprocess(ctx) - if isinstance(preprocess_ret, ErrorResponse): - yield preprocess_ret - - generators_ret = await self._prepare_generators(ctx) - if isinstance(generators_ret, ErrorResponse): - yield generators_ret - - collect_ret = await self._collect_batch(ctx) - if isinstance(collect_ret, ErrorResponse): - yield collect_ret - - yield self._build_response(ctx) - - def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: - truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - - if ( - truncate_prompt_tokens is not None - and truncate_prompt_tokens > self.model_config.max_model_len - ): - return self.create_error_response( - "truncate_prompt_tokens value is " - "greater than max_model_len." - " Please request a smaller truncation size." - ) - return None - - def _create_pooling_params( - self, - ctx: ServeContext, - ) -> PoolingParams | ErrorResponse: - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters" - ) - - return ctx.request.to_pooling_params() - - async def _prepare_generators( - self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """Schedule the request and get the result generator.""" - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - - trace_headers = ( - None - if ctx.raw_request is None - else await self._get_trace_headers(ctx.raw_request.headers) - ) - - pooling_params = self._create_pooling_params(ctx) - if isinstance(pooling_params, ErrorResponse): - return pooling_params - - if ctx.engine_inputs is None: - return self.create_error_response("Engine prompts not available") - - for i, engine_input in enumerate(ctx.engine_inputs): - request_id_item = f"{ctx.request_id}-{i}" - - self._log_inputs( - request_id_item, - engine_input, - params=pooling_params, - lora_request=ctx.lora_request, - ) - - generator = self.engine_client.encode( - engine_input, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) - - generators.append(generator) - - ctx.result_generator = merge_async_iterators(*generators) - - return None - - async def _collect_batch( - self, - ctx: ServeContext, - ) -> ErrorResponse | None: - """Collect batch results from the result generator.""" - if ctx.engine_inputs is None: - return self.create_error_response("Engine prompts not available") - - num_prompts = len(ctx.engine_inputs) - final_res_batch: list[PoolingRequestOutput | None] - final_res_batch = [None] * num_prompts - - if ctx.result_generator is None: - return self.create_error_response("Result generator not available") - - async for i, res in ctx.result_generator: - final_res_batch[i] = res - - if None in final_res_batch: - return self.create_error_response( - "Failed to generate results for all prompts" - ) - - ctx.final_res_batch = [res for 
res in final_res_batch if res is not None] - - return None - @staticmethod def create_error_response( message: str | Exception, @@ -719,7 +544,7 @@ class OpenAIServing: self, request_id: str, inputs: PromptType | EngineInput, - params: SamplingParams | PoolingParams | BeamSearchParams | None, + params: SamplingParams | BeamSearchParams | None, lora_request: LoRARequest | None, ) -> None: if self.request_logger is None: diff --git a/vllm/entrypoints/openai/models/serving.py b/vllm/entrypoints/openai/models/serving.py index dd7a8687f..ba1902d5b 100644 --- a/vllm/entrypoints/openai/models/serving.py +++ b/vllm/entrypoints/openai/models/serving.py @@ -112,7 +112,6 @@ class OpenAIServingModels: self.model_config = self.engine_client.model_config self.renderer = self.engine_client.renderer - self.io_processor = self.engine_client.io_processor self.input_processor = self.engine_client.input_processor async def init_static_loras(self): diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index fb0c10e6f..1980750ec 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -67,20 +67,18 @@ def init_pooling_state( from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import ServingEmbedding - from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling + from vllm.entrypoints.pooling.pooling.serving import ServingPooling from vllm.entrypoints.pooling.scoring.serving import ServingScores from vllm.tasks import POOLING_TASKS model_config = engine_client.model_config - resolved_chat_template = load_chat_template(args.chat_template) state.serving_pooling = ( ( - OpenAIServingPooling( + ServingPooling( engine_client, state.openai_serving_models, - state.openai_serving_render, supported_tasks=supported_tasks, request_logger=request_logger, chat_template=resolved_chat_template, diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py index fd4c076cd..79f350382 100644 --- a/vllm/entrypoints/pooling/base/io_processor.py +++ b/vllm/entrypoints/pooling/base/io_processor.py @@ -4,8 +4,8 @@ from collections.abc import Sequence from typing import Any, Final -from vllm import PoolingRequestOutput, PromptType -from vllm.config import ModelConfig +from vllm import PoolingParams, PoolingRequestOutput, PromptType +from vllm.config import VllmConfig from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateConfig, @@ -33,11 +33,12 @@ class PoolingIOProcessor: def __init__( self, - model_config: ModelConfig, + vllm_config: VllmConfig, renderer: BaseRenderer, chat_template_config: ChatTemplateConfig, ): - self.model_config = model_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config self.renderer = renderer self.chat_template = chat_template_config.chat_template @@ -48,12 +49,12 @@ class PoolingIOProcessor: chat_template_config.trust_request_chat_template ) - def create_pooling_params(self, request): - return request.to_pooling_params() - ####################################### # online APIs + def create_pooling_params(self, request): + return request.to_pooling_params() + def pre_process_online(self, ctx: PoolingServeContext): request = ctx.request @@ -100,12 +101,16 @@ class PoolingIOProcessor: # offline APIs def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: - assert not 
isinstance(ctx.prompts, ScoringData) + assert not isinstance(ctx.prompts, ScoringData) and not ( + isinstance(ctx.prompts, dict) and "data" in ctx.prompts + ) + + prompts_seq = prompt_to_seq(ctx.prompts) tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( **(ctx.tokenization_kwargs or {}) ) return self._preprocess_completion_offline( - prompts=ctx.prompts, tok_params=tok_params + prompts=prompts_seq, tok_params=tok_params ) async def pre_process_offline_async(self, ctx: OfflineInputsContext): @@ -243,3 +248,19 @@ class PoolingIOProcessor: "Refused request with untrusted chat template." ) return None + + def _params_to_seq( + self, + params: PoolingParams | Sequence[PoolingParams], + num_requests: int, + ) -> Sequence[PoolingParams]: + if isinstance(params, Sequence): + if len(params) != num_requests: + raise ValueError( + f"The lengths of prompts ({num_requests}) " + f"and params ({len(params)}) must be the same." + ) + + return params + + return [params] * num_requests diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py index 90554aa63..10f5cd6b5 100644 --- a/vllm/entrypoints/pooling/base/serving.py +++ b/vllm/entrypoints/pooling/base/serving.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Mapping from http import HTTPStatus from typing import ClassVar @@ -9,7 +11,7 @@ from fastapi.responses import Response from starlette.datastructures import Headers from vllm import PoolingParams, PoolingRequestOutput, envs -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ( ChatTemplateConfig, @@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators from .io_processor import PoolingIOProcessor -class PoolingServing: +class PoolingServingBase(ABC): request_id_prefix: ClassVar[str] def __init__( @@ -50,10 +52,11 @@ class PoolingServing: return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, ): - super().__init__() self.engine_client = engine_client self.models = models self.model_config = models.model_config + self.renderer = models.renderer + self.vllm_config = engine_client.vllm_config self.max_model_len = self.model_config.max_model_len self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids @@ -63,31 +66,14 @@ class PoolingServing: chat_template_content_format=chat_template_content_format, trust_request_chat_template=trust_request_chat_template, ) - self.io_processor = self.init_io_processor( - model_config=models.model_config, - renderer=models.renderer, - chat_template_config=self.chat_template_config, - ) - - def init_io_processor( - self, - model_config: ModelConfig, - renderer: BaseRenderer, - chat_template_config: ChatTemplateConfig, - ) -> PoolingIOProcessor: - raise NotImplementedError + @abstractmethod async def __call__( self, request: AnyPoolingRequest, raw_request: Request | None = None, ) -> Response: - ctx = await self._init_ctx(request, raw_request) - await self.io_processor.pre_process_online_async(ctx) - await self._prepare_generators(ctx) - await self._collect_batch(ctx) - await self.io_processor.post_process_online_async(ctx) - return await self._build_response(ctx) + raise NotImplementedError async def _init_ctx( self, @@ -124,10 +110,8 @@ class PoolingServing: else await 
self._get_trace_headers(ctx.raw_request.headers) ) - if ctx.pooling_params is None: - pooling_params = self.io_processor.create_pooling_params(ctx.request) - else: - pooling_params = ctx.pooling_params + assert ctx.pooling_params is not None + pooling_params = ctx.pooling_params if isinstance(pooling_params, list): for params in pooling_params: @@ -190,6 +174,7 @@ class PoolingServing: ctx.final_res_batch = [res for res in final_res_batch if res is not None] + @abstractmethod async def _build_response( self, ctx: PoolingServeContext, @@ -355,3 +340,39 @@ class PoolingServing: params=params, lora_request=lora_request, ) + + +class PoolingServing(PoolingServingBase, ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.io_processor = self.init_io_processor( + vllm_config=self.vllm_config, + renderer=self.renderer, + chat_template_config=self.chat_template_config, + ) + + @abstractmethod + def init_io_processor( + self, + vllm_config: VllmConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> PoolingIOProcessor: + raise NotImplementedError + + async def __call__( + self, + request: AnyPoolingRequest, + raw_request: Request | None = None, + ) -> Response: + ctx = await self._init_ctx(request, raw_request) + await self.io_processor.pre_process_online_async(ctx) + + if ctx.pooling_params is None: + ctx.pooling_params = self.io_processor.create_pooling_params(request) + + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + await self.io_processor.post_process_online_async(ctx) + return await self._build_response(ctx) diff --git a/vllm/entrypoints/pooling/classify/io_processor.py b/vllm/entrypoints/pooling/classify/io_processor.py index ee73207df..9bb3774ab 100644 --- a/vllm/entrypoints/pooling/classify/io_processor.py +++ b/vllm/entrypoints/pooling/classify/io_processor.py @@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor class ClassifyIOProcessor(PoolingIOProcessor): - name = "classification" + name = "classify" + + +class TokenClassifyIOProcessor(PoolingIOProcessor): + name = "token_classify" diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 24d4f9aac..0a729075b 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -6,14 +6,11 @@ from typing import TypeAlias import numpy as np from fastapi.responses import JSONResponse -from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import ChatTemplateConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.logger import init_logger from vllm.outputs import ClassificationOutput -from vllm.renderers import BaseRenderer from .io_processor import ClassifyIOProcessor from .protocol import ( @@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques class ServingClassification(PoolingServing): request_id_prefix = "classify" - def init_io_processor( - self, - model_config: ModelConfig, - renderer: BaseRenderer, - chat_template_config: ChatTemplateConfig, - ) -> ClassifyIOProcessor: - return ClassifyIOProcessor( - model_config=model_config, - renderer=renderer, - chat_template_config=chat_template_config, - ) + def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor: + return ClassifyIOProcessor(*args, 
**kwargs) async def _build_response( self, diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py index 614f8e0d9..623fee4fd 100644 --- a/vllm/entrypoints/pooling/embed/io_processor.py +++ b/vllm/entrypoints/pooling/embed/io_processor.py @@ -37,7 +37,7 @@ logger = init_logger(__name__) class EmbedIOProcessor(PoolingIOProcessor): - name = "embedding" + name = "embed" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor): request = ctx.request if request.truncate == "NONE" and request.max_tokens is not None: self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens) + + +class TokenEmbedIOProcessor(PoolingIOProcessor): + name = "token_embed" diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index f0c331645..3abf0c7f3 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast from fastapi.responses import JSONResponse, Response, StreamingResponse from typing_extensions import assert_never -from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import ChatTemplateConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor @@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import ( ) from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput -from vllm.renderers import BaseRenderer from vllm.utils.serial_utils import EmbedDType, Endianness logger = init_logger(__name__) -JSONResponseCLS = get_json_response_cls() EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest] @@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing): request_id_prefix = "embd" io_processor: EmbedIOProcessor - def init_io_processor( - self, - model_config: ModelConfig, - renderer: BaseRenderer, - chat_template_config: ChatTemplateConfig, - ) -> EmbedIOProcessor: - return EmbedIOProcessor( - model_config=model_config, - renderer=renderer, - chat_template_config=chat_template_config, - ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.json_response_cls = get_json_response_cls() + + def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor: + return EmbedIOProcessor(*args, **kwargs) async def _build_response( self, @@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing): data=items, usage=usage, ) - return JSONResponseCLS(content=response.model_dump()) + return self.json_response_cls(content=response.model_dump()) def _openai_bytes_response( self, @@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing): media_type=response.media_type, ) - @staticmethod def _build_cohere_response_from_ctx( + self, ctx: PoolingServeContext, ) -> JSONResponse: request = ctx.request @@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing): ), ), ) - return JSONResponse(content=response.model_dump(exclude_none=True)) + return self.json_response_cls(content=response.model_dump(exclude_none=True)) diff --git a/vllm/entrypoints/pooling/io_processor_factories.py b/vllm/entrypoints/pooling/io_processor_factories.py index 71033bd23..de60a746a 100644 --- a/vllm/entrypoints/pooling/io_processor_factories.py +++ b/vllm/entrypoints/pooling/io_processor_factories.py @@ -1,42 +1,69 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.entrypoints.chat_utils import ChatTemplateConfig -from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor -from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors -from vllm.entrypoints.pooling.utils import enable_scoring_api +from vllm.plugins.io_processors import has_io_processor from vllm.renderers import BaseRenderer from vllm.tasks import SupportedTask +from .base.io_processor import PoolingIOProcessor +from .utils import enable_scoring_api + def init_pooling_io_processors( supported_tasks: tuple[SupportedTask, ...], - model_config: ModelConfig, + vllm_config: VllmConfig, renderer: BaseRenderer, chat_template_config: ChatTemplateConfig, ) -> dict[str, PoolingIOProcessor]: - processors: list[tuple[str, type[PoolingIOProcessor]]] = [] + model_config = vllm_config.model_config + processors: dict[str, type[PoolingIOProcessor]] = {} + if "classify" in supported_tasks: - from vllm.entrypoints.pooling.classify.io_processor import ClassifyIOProcessor + from .classify.io_processor import ClassifyIOProcessor + + processors["classify"] = ClassifyIOProcessor + + if "token_classify" in supported_tasks: + from .classify.io_processor import TokenClassifyIOProcessor + + processors["token_classify"] = TokenClassifyIOProcessor - processors.append(("classify", ClassifyIOProcessor)) if "embed" in supported_tasks: - from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor + from .embed.io_processor import EmbedIOProcessor - processors.append(("embed", EmbedIOProcessor)) + processors["embed"] = EmbedIOProcessor + + if "token_embed" in supported_tasks: + from .embed.io_processor import TokenEmbedIOProcessor + + processors["token_embed"] = TokenEmbedIOProcessor + + if has_io_processor( + vllm_config, + model_config.io_processor_plugin, + ): + from .pooling.io_processor import PluginWithIOProcessorPlugins + + processors["plugin"] = PluginWithIOProcessorPlugins + elif "plugin" in supported_tasks: + from .pooling.io_processor import PluginWithoutIOProcessorPlugins + + processors["plugin"] = PluginWithoutIOProcessorPlugins if enable_scoring_api(supported_tasks, model_config): score_type = model_config.score_type + from .scoring.io_processor import ScoringIOProcessors + if score_type is not None and score_type in ScoringIOProcessors: - processors.append((score_type, ScoringIOProcessors[score_type])) + processors[score_type] = ScoringIOProcessors[score_type] return { task: processor_cls( - model_config=model_config, + vllm_config=vllm_config, renderer=renderer, chat_template_config=chat_template_config, ) - for task, processor_cls in processors + for task, processor_cls in processors.items() } diff --git a/vllm/entrypoints/pooling/pooling/api_router.py b/vllm/entrypoints/pooling/pooling/api_router.py index f63a8edf6..a08570038 100644 --- a/vllm/entrypoints/pooling/pooling/api_router.py +++ b/vllm/entrypoints/pooling/pooling/api_router.py @@ -3,24 +3,17 @@ from http import HTTPStatus from fastapi import APIRouter, Depends, Request -from fastapi.responses import JSONResponse, StreamingResponse -from typing_extensions import assert_never from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request -from vllm.entrypoints.pooling.pooling.protocol import ( - IOProcessorResponse, - PoolingBytesResponse, - 
PoolingRequest, - PoolingResponse, -) -from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling +from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest +from vllm.entrypoints.pooling.pooling.serving import ServingPooling from vllm.entrypoints.utils import load_aware_call, with_cancellation router = APIRouter() -def pooling(request: Request) -> OpenAIServingPooling | None: +def pooling(request: Request) -> ServingPooling | None: return request.app.state.serving_pooling @@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): if handler is None: raise NotImplementedError("The model does not support Pooling API") - generator = await handler.create_pooling(request, raw_request) - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): - return JSONResponse(content=generator.model_dump()) - elif isinstance(generator, PoolingBytesResponse): - return StreamingResponse( - content=generator.content, - headers=generator.headers, - media_type=generator.media_type, - ) - - assert_never(generator) + return await handler(request, raw_request) diff --git a/vllm/entrypoints/pooling/pooling/io_processor.py b/vllm/entrypoints/pooling/pooling/io_processor.py new file mode 100644 index 000000000..31f860144 --- /dev/null +++ b/vllm/entrypoints/pooling/pooling/io_processor.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +from vllm import PoolingParams, PoolingRequestOutput +from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.inputs import EngineInput +from vllm.logger import init_logger +from vllm.plugins.io_processors import get_io_processor +from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq + +from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext +from .protocol import IOProcessorRequest, IOProcessorResponse + +logger = init_logger(__name__) + + +class PluginWithoutIOProcessorPlugins(PoolingIOProcessor): + name = "plugin" + + +class PluginWithIOProcessorPlugins(PoolingIOProcessor): + """IO Processor plugins are a feature that allows pre- and post-processing + of the model input and output for pooling models.""" + + name = "plugin" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + io_processor = get_io_processor( + self.vllm_config, + self.renderer, + self.model_config.io_processor_plugin, + ) + + assert io_processor is not None + self.io_processor = io_processor + + ####################################### + # online APIs + + def pre_process_online(self, ctx: PoolingServeContext): + assert isinstance(ctx.request, IOProcessorRequest) + + validated_prompt = self.io_processor.parse_data(ctx.request.data) + + raw_prompts = self.io_processor.pre_process( + prompt=validated_prompt, request_id=ctx.request_id + ) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(self.model_config, prompt) + ) + for prompt in prompt_to_seq(raw_prompts) + ] + + tok_params = ctx.request.build_tok_params(self.model_config) + + ctx.engine_inputs = self.renderer.render_cmpl( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(ctx.request, k, None)) 
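Because the serving object is now callable and returns a ready FastAPI Response, the pooling route collapses to a single delegation call. A small sketch of that wiring, assuming a stand-in ServingPooling and a plain dict body instead of the real PoolingRequest model:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response

app = FastAPI()


class ServingPooling:
    """Stand-in: the real class pre-processes, runs the engine, post-processes."""

    async def __call__(self, request_body: dict, raw_request: Request) -> Response:
        return JSONResponse(content={"object": "list", "data": []})


app.state.serving_pooling = ServingPooling()


@app.post("/pooling")
async def create_pooling(request_body: dict, raw_request: Request) -> Response:
    handler = raw_request.app.state.serving_pooling
    if handler is None:
        raise NotImplementedError("The model does not support Pooling API")
    # The serving object builds the HTTP response itself, so the router no
    # longer has to map protocol objects onto JSON/streaming responses.
    return await handler(request_body, raw_request)
```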
is not None + }, + ) + + pooling_params = self.io_processor.merge_pooling_params() + if pooling_params.task is None: + pooling_params.task = "plugin" + ctx.pooling_params = pooling_params + + def post_process_online( + self, + ctx: PoolingServeContext, + ): + output = self.io_processor.post_process( + ctx.final_res_batch, + request_id=ctx.request_id, + ) + + if callable( + output_to_response := getattr(self.io_processor, "output_to_response", None) + ): + logger.warning_once( + "`IOProcessor.output_to_response` is deprecated. To ensure " + "consistency between offline and online APIs, " + "`IOProcessorResponse` will become a transparent wrapper " + "around output data from v0.19 onwards.", + ) + + if hasattr(output, "request_id") and output.request_id is None: + output.request_id = ctx.request_id # type: ignore + + ctx.response = output_to_response(output) # type: ignore + else: + ctx.response = IOProcessorResponse(request_id=ctx.request_id, data=output) + + ####################################### + # offline APIs + + def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: + assert isinstance(ctx.prompts, dict) and "data" in ctx.prompts + assert ctx.pooling_params is not None + + # Validate the request data is valid for the loaded plugin + prompt_data = ctx.prompts.get("data") + if prompt_data is None: + raise ValueError( + "The 'data' field of the prompt is expected to contain " + "the prompt data and it cannot be None. " + "Refer to the documentation of the IOProcessor " + "in use for more details." + ) + validated_prompt = self.io_processor.parse_data(prompt_data) + + # obtain the actual model prompts from the pre-processor + prompts = self.io_processor.pre_process(prompt=validated_prompt) + prompts_seq = prompt_to_seq(prompts) + + params_seq: list[PoolingParams] = [ + self.io_processor.merge_pooling_params(param) + for param in self._params_to_seq( + ctx.pooling_params, + len(prompts_seq), + ) + ] + for p in params_seq: + if p.task is None: + p.task = "plugin" + + ctx.pooling_params = params_seq + ctx.prompts = prompts_seq + return super().pre_process_offline(ctx) + + def post_process_offline( + self, + ctx: OfflineOutputsContext, + ) -> list[PoolingRequestOutput]: + processed_outputs = self.io_processor.post_process(ctx.outputs) + + return [ + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + num_cached_tokens=getattr(processed_outputs, "num_cached_tokens", 0), + prompt_token_ids=[], + finished=True, + ) + ] diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 4706684f3..3d16c1f2c 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -1,252 +1,157 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio + import json -import time -from collections.abc import AsyncGenerator, Callable, Sequence +from collections.abc import Callable from functools import partial -from typing import Final, Literal, cast +from typing import Literal, cast from fastapi import Request +from fastapi.responses import JSONResponse, Response, StreamingResponse from typing_extensions import assert_never -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from 
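post_process_online keeps a compatibility branch for plugins that still implement output_to_response before falling back to the IOProcessorResponse wrapper. A sketch of that branching with simplified stand-ins for the plugin and response classes:

```python
import warnings
from dataclasses import dataclass
from typing import Any


@dataclass
class IOProcessorResponse:
    """Stand-in for the protocol class that wraps plugin output."""

    request_id: str
    data: Any


def build_plugin_response(io_processor: Any, output: Any, request_id: str):
    hook = getattr(io_processor, "output_to_response", None)
    if callable(hook):
        # Legacy path, kept only for backwards compatibility.
        warnings.warn(
            "`output_to_response` is deprecated; plugin outputs will be "
            "wrapped in IOProcessorResponse in a future release.",
            DeprecationWarning,
        )
        if hasattr(output, "request_id") and output.request_id is None:
            output.request_id = request_id
        return hook(output)
    # Default path: transparent wrapper around the plugin output.
    return IOProcessorResponse(request_id=request_id, data=output)


class LegacyPlugin:
    def output_to_response(self, output):
        return {"legacy": output}


print(build_plugin_response(LegacyPlugin(), {"score": 1.0}, "req-1"))
print(build_plugin_response(object(), {"score": 1.0}, "req-2"))
```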
vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.serving import PoolingServingBase +from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.pooling.protocol import ( IOProcessorRequest, - IOProcessorResponse, PoolingBytesResponse, - PoolingChatRequest, - PoolingCompletionRequest, PoolingRequest, PoolingResponse, PoolingResponseData, ) +from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext from vllm.entrypoints.pooling.utils import ( encode_pooling_bytes, encode_pooling_output_base64, encode_pooling_output_float, + get_json_response_cls, ) -from vllm.entrypoints.serve.render.serving import OpenAIServingRender -from vllm.inputs import EngineInput from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput -from vllm.renderers.inputs.preprocess import prompt_to_seq from vllm.tasks import SupportedTask -from vllm.utils.async_utils import merge_async_iterators -from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness +from vllm.utils.serial_utils import EmbedDType, Endianness logger = init_logger(__name__) -class OpenAIServingPooling(OpenAIServing): +class ServingPooling(PoolingServingBase): + request_id_prefix = "pooling" + def __init__( self, - engine_client: EngineClient, - models: OpenAIServingModels, - openai_serving_render: OpenAIServingRender, + *args, supported_tasks: tuple[SupportedTask, ...], - *, - request_logger: RequestLogger | None, - chat_template: str | None, - chat_template_content_format: ChatTemplateContentFormatOption, - trust_request_chat_template: bool = False, - ) -> None: - super().__init__( - engine_client=engine_client, - models=models, - request_logger=request_logger, - ) + **kwargs, + ): + super().__init__(*args, **kwargs) + self.supported_tasks = supported_tasks self.pooling_task = self.model_config.get_pooling_task(supported_tasks) - self.openai_serving_render = openai_serving_render - self.chat_template = chat_template - self.chat_template_content_format: Final = chat_template_content_format - self.trust_request_chat_template = trust_request_chat_template + self.io_processors = init_pooling_io_processors( + supported_tasks=supported_tasks, + vllm_config=self.vllm_config, + renderer=self.renderer, + chat_template_config=self.chat_template_config, + ) + self.json_response_cls = get_json_response_cls() - async def create_pooling( + async def __call__( self, - request: PoolingRequest, + request: AnyPoolingRequest, raw_request: Request | None = None, - ) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse: - """ - See https://platform.openai.com/docs/api-reference/embeddings/create - for the API specification. This API mimics the OpenAI Embedding API. 
- """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret + ) -> Response: + assert isinstance(request, PoolingRequest) + pooling_task = self._verify_pooling_task(request) - model_name = self.models.model_name() + io_processor = self.io_processors[pooling_task] + ctx = await self._init_ctx(request, raw_request) - request_id = f"pool-{self._base_request_id(raw_request)}" - created_time = int(time.time()) + await io_processor.pre_process_online_async(ctx) - lora_request = self._maybe_get_adapters(request) + if ctx.pooling_params is None: + ctx.pooling_params = io_processor.create_pooling_params(request) + + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + + await io_processor.post_process_online_async(ctx) + return await self._build_response(ctx) + + def _verify_pooling_task(self, request: PoolingRequest) -> str: + if getattr(request, "dimensions", None) is not None: + raise ValueError("dimensions is currently not supported") if request.task is None: request.task = self.pooling_task - if getattr(request, "dimensions", None) is not None: - return self.create_error_response("dimensions is currently not supported") + if isinstance(request, IOProcessorRequest): + request.task = "plugin" + + assert request.task is not None + pooling_task = request.task # plugin task uses io_processor.parse_request to verify inputs - if request.task != "plugin" and request.task != self.pooling_task: - if request.task not in self.supported_tasks: + if pooling_task != "plugin" and pooling_task != self.pooling_task: + if pooling_task not in self.io_processors: raise ValueError( - f"Unsupported task: {request.task!r} " + f"Unsupported task: {pooling_task!r} " f"Supported tasks: {self.supported_tasks}" ) else: logger.warning_once( "Pooling multitask support is deprecated and will be removed " "in v0.20. When the default pooling task is not what you want, you " - 'need to manually specify it via --pooler-config.task "%s". ', - request.task, + "need to manually specify it via --pooler-config.task %s. ", + pooling_task, ) - engine_inputs: Sequence[EngineInput] - if use_io_processor := isinstance(request, IOProcessorRequest): - if self.io_processor is None: - raise ValueError( - "No IOProcessor plugin installed. Please refer " - "to the documentation and to the " - "'prithvi_geospatial_mae_io_processor' " - "offline inference example for more details." 
- ) - - validated_prompt = self.io_processor.parse_data(request.data) - - raw_prompts = await self.io_processor.pre_process_async( - prompt=validated_prompt, request_id=request_id - ) - engine_inputs = await self.openai_serving_render.preprocess_cmpl( - request, - prompt_to_seq(raw_prompts), - ) - elif isinstance(request, PoolingChatRequest): - error_check_ret = self.openai_serving_render.validate_chat_template( - request_chat_template=request.chat_template, - chat_template_kwargs=request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret is not None: - return error_check_ret - - _, engine_inputs = await self.openai_serving_render.preprocess_chat( - request, - request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(request, PoolingCompletionRequest): - engine_inputs = await self.openai_serving_render.preprocess_completion( - request, - prompt_input=request.input, - prompt_embeds=None, - ) - else: - raise ValueError(f"Unsupported request of type {type(request)}") - - # Schedule the request and get the result generator. - generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - if use_io_processor: - assert self.io_processor is not None - - pooling_params = self.io_processor.merge_pooling_params() - if pooling_params.task is None: - pooling_params.task = "plugin" - else: - pooling_params = request.to_pooling_params() # type: ignore - - for i, engine_input in enumerate(engine_inputs): - request_id_item = f"{request_id}-{i}" - - self._log_inputs( - request_id_item, - engine_input, - params=pooling_params, - lora_request=lora_request, + if pooling_task == "plugin" and "plugin" not in self.io_processors: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details." 
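The old inline validation is replaced by _verify_pooling_task. Roughly, the logic is: fall back to the model's default task, force the plugin task for IO-processor requests, reject tasks with no registered processor, and warn when a non-default task is used. A condensed, stand-alone sketch (the request and registry shapes are simplified assumptions):

```python
import warnings


def verify_pooling_task(request: dict, default_task: str, io_processors: dict) -> str:
    task = request.get("task") or default_task
    if request.get("is_io_processor_request"):
        task = "plugin"

    if task not in ("plugin", default_task):
        if task not in io_processors:
            raise ValueError(f"Unsupported task: {task!r}")
        warnings.warn(
            "Pooling multitask support is deprecated; set the task via "
            "--pooler-config.task instead.",
            DeprecationWarning,
        )

    if task == "plugin" and "plugin" not in io_processors:
        raise ValueError("No IOProcessor plugin installed.")
    return task


print(verify_pooling_task({}, "embed", {"embed": object()}))           # -> embed
print(verify_pooling_task({"task": "classify"}, "embed",
                          {"embed": object(), "classify": object()}))  # -> classify
```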
) - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) + return pooling_task + + async def _build_response( + self, + ctx: PoolingServeContext, + ) -> Response: + if ctx.response is not None: + # for IOProcessorResponse + return self.json_response_cls(content=ctx.response.model_dump()) + + encoding_format = ctx.request.encoding_format + embed_dtype = ctx.request.embed_dtype + endianness = ctx.request.endianness + + if encoding_format == "float" or encoding_format == "base64": + return self.request_output_to_pooling_json_response( + ctx.final_res_batch, + ctx.request_id, + ctx.created_time, + ctx.model_name, + encoding_format, + embed_dtype, + endianness, ) - generator = self.engine_client.encode( - engine_input, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, + if encoding_format == "bytes" or encoding_format == "bytes_only": + return self.request_output_to_pooling_bytes_response( + ctx.final_res_batch, + ctx.request_id, + ctx.created_time, + ctx.model_name, + encoding_format, + embed_dtype, + endianness, ) - generators.append(generator) - - result_generator = merge_async_iterators(*generators) - - if use_io_processor: - assert self.io_processor is not None - output = await self.io_processor.post_process_async( - result_generator, - request_id=request_id, - ) - - if callable( - output_to_response := getattr( - self.io_processor, "output_to_response", None - ) - ): - logger.warning_once( - "`IOProcessor.output_to_response` is deprecated. To ensure " - "consistency between offline and online APIs, " - "`IOProcessorResponse` will become a transparent wrapper " - "around output data from v0.19 onwards.", - ) - - if hasattr(output, "request_id") and output.request_id is None: - output.request_id = request_id # type: ignore - - return output_to_response(output) # type: ignore - - return IOProcessorResponse(request_id=request_id, data=output) - - assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) - num_prompts = len(engine_inputs) - - # Non-streaming response - final_res_batch: list[PoolingRequestOutput | None] - final_res_batch = [None] * num_prompts - try: - async for i, res in result_generator: - final_res_batch[i] = res - - assert all(final_res is not None for final_res in final_res_batch) - - final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch) - - response = self.request_output_to_pooling_response( - final_res_batch_checked, - request_id, - created_time, - model_name, - request.encoding_format, - request.embed_dtype, - request.endianness, - ) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - - return response + assert_never(encoding_format) def request_output_to_pooling_json_response( self, @@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing): encoding_format: Literal["float", "base64"], embed_dtype: EmbedDType, endianness: Endianness, - ) -> PoolingResponse: + ) -> JSONResponse: encode_fn = cast( Callable[[PoolingRequestOutput], list[float] | str], ( @@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing): total_tokens=num_prompt_tokens, ) - return PoolingResponse( + response = PoolingResponse( id=request_id, created=created_time, model=model_name, data=items, usage=usage, ) + return self.json_response_cls(content=response.model_dump()) def request_output_to_pooling_bytes_response( self, @@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing): 
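_build_response and the methods around it implement the new context-driven pipeline: pre-process, run the engine, post-process, then turn the context into a response. A compressed sketch of that flow; every class and method name below is a simplified stand-in for the real PoolingServing machinery:

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ServeContext:
    request: Any
    request_id: str
    final_res_batch: list = field(default_factory=list)
    pooling_params: Any | None = None
    response: Any | None = None  # set by plugin post-processing, if any


class EmbedProcessor:
    async def pre_process(self, ctx):
        ...  # would render prompts and build pooling params

    async def post_process(self, ctx):
        ...  # would validate / reshape the pooled outputs


class ServingPooling:
    def __init__(self, io_processors: dict[str, Any]):
        self.io_processors = io_processors

    async def __call__(self, request: dict) -> dict:
        task = self._verify_pooling_task(request)
        io_processor = self.io_processors[task]
        ctx = ServeContext(request=request, request_id="pool-0")

        await io_processor.pre_process(ctx)   # build engine inputs + params
        await self._run_engine(ctx)           # fill ctx.final_res_batch
        await io_processor.post_process(ctx)  # may set ctx.response
        return self._build_response(ctx)

    def _verify_pooling_task(self, request: dict) -> str:
        task = request.get("task", "embed")
        if task not in self.io_processors:
            raise ValueError(f"Unsupported task: {task!r}")
        return task

    async def _run_engine(self, ctx: ServeContext) -> None:
        ctx.final_res_batch = [{"embedding": [0.0, 1.0]}]  # fake engine output

    def _build_response(self, ctx: ServeContext) -> dict:
        if ctx.response is not None:  # IOProcessorResponse path
            return ctx.response
        return {"data": ctx.final_res_batch}


async def main():
    serving = ServingPooling({"embed": EmbedProcessor()})
    print(await serving({"task": "embed"}))


asyncio.run(main())
```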
encoding_format: Literal["bytes", "bytes_only"], embed_dtype: EmbedDType, endianness: Endianness, - ) -> PoolingBytesResponse: + ) -> StreamingResponse: content, items, usage = encode_pooling_bytes( pooling_outputs=final_res_batch, embed_dtype=embed_dtype, @@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing): } ) - return PoolingBytesResponse(content=content, headers=headers) + response = PoolingBytesResponse(content=content, headers=headers) - def request_output_to_pooling_response( - self, - final_res_batch: list[PoolingRequestOutput], - request_id: str, - created_time: int, - model_name: str, - encoding_format: EncodingFormat, - embed_dtype: EmbedDType, - endianness: Endianness, - ) -> PoolingResponse | PoolingBytesResponse: - if encoding_format == "float" or encoding_format == "base64": - return self.request_output_to_pooling_json_response( - final_res_batch, - request_id, - created_time, - model_name, - encoding_format, - embed_dtype, - endianness, - ) - - if encoding_format == "bytes" or encoding_format == "bytes_only": - return self.request_output_to_pooling_bytes_response( - final_res_batch, - request_id, - created_time, - model_name, - encoding_format, - embed_dtype, - endianness, - ) - - assert_never(encoding_format) + return StreamingResponse( + content=response.content, + headers=response.headers, + media_type=response.media_type, + ) diff --git a/vllm/entrypoints/pooling/scoring/io_processor.py b/vllm/entrypoints/pooling/scoring/io_processor.py index c520eb5ce..dd505c79c 100644 --- a/vllm/entrypoints/pooling/scoring/io_processor.py +++ b/vllm/entrypoints/pooling/scoring/io_processor.py @@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: assert isinstance(ctx.prompts, ScoringData) - assert not isinstance(ctx.pooling_params, list) + assert not isinstance(ctx.pooling_params, Sequence) tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( **(ctx.tokenization_kwargs or {}) diff --git a/vllm/entrypoints/pooling/scoring/serving.py b/vllm/entrypoints/pooling/scoring/serving.py index de5b5797c..fc5207561 100644 --- a/vllm/entrypoints/pooling/scoring/serving.py +++ b/vllm/entrypoints/pooling/scoring/serving.py @@ -4,15 +4,13 @@ from fastapi.responses import JSONResponse, Response from vllm import PoolingParams -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.renderers import BaseRenderer from vllm.v1.pool.late_interaction import ( build_late_interaction_doc_params, build_late_interaction_query_params, @@ -52,22 +50,17 @@ class ServingScores(PoolingServing): super().__init__(engine_client, *args, **kwargs) def init_io_processor( - self, - model_config: ModelConfig, - renderer: BaseRenderer, - chat_template_config: ChatTemplateConfig, + self, vllm_config: VllmConfig, *args, **kwargs ) -> PoolingIOProcessor: + model_config = vllm_config.model_config + score_type: str = model_config.score_type if self.enable_flash_late_interaction: score_type = "flash-late-interaction" assert score_type in ScoringIOProcessors processor_cls = 
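The scoring assertion is widened from list to Sequence so that any batch-like container of per-prompt params is rejected, not just lists. A tiny illustration of the difference (PoolingParams is a stand-in here):

```python
from collections.abc import Sequence


class PoolingParams:
    """Stand-in for vllm.PoolingParams."""


def expects_single_params(params):
    # isinstance(params, list) would let a tuple of per-prompt params through;
    # checking against Sequence rejects any batch-like container.
    assert not isinstance(params, Sequence), "expected a single PoolingParams"
    return params


expects_single_params(PoolingParams())  # accepted
try:
    expects_single_params((PoolingParams(), PoolingParams()))
except AssertionError as exc:
    print("rejected:", exc)
```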
ScoringIOProcessors[score_type] - return processor_cls( - model_config=model_config, - renderer=renderer, - chat_template_config=chat_template_config, - ) + return processor_cls(vllm_config, *args, **kwargs) async def __call__(self, *args, **kwargs) -> Response: if not self.enable_flash_late_interaction: diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py index 66dd9dd4d..ead772251 100644 --- a/vllm/entrypoints/pooling/typing.py +++ b/vllm/entrypoints/pooling/typing.py @@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ) from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse from vllm.entrypoints.pooling.scoring.typing import ScoringData -from vllm.inputs import EngineInput +from vllm.inputs import DataPrompt, EngineInput from vllm.lora.request import LoRARequest PoolingCompletionLikeRequest: TypeAlias = ( @@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]): ## for bi-encoder & late-interaction n_queries: int | None = None + ## for IOProcessorResponse + response: Any | None = None + @dataclass class OfflineInputsContext: - prompts: PromptType | Sequence[PromptType] | ScoringData - pooling_params: PoolingParams | list[PoolingParams] | None = None + prompts: PromptType | Sequence[PromptType] | DataPrompt | ScoringData + pooling_params: PoolingParams | Sequence[PoolingParams] tokenization_kwargs: dict[str, Any] | None = None chat_template: str | None = None diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py index 45f5613bf..1d63793df 100644 --- a/vllm/entrypoints/sagemaker/api_router.py +++ b/vllm/entrypoints/sagemaker/api_router.py @@ -14,7 +14,7 @@ from vllm.config import ModelConfig from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.utils import validate_json_request -from vllm.entrypoints.pooling.base.serving import PoolingServing +from vllm.entrypoints.pooling.base.serving import PoolingServingBase from vllm.entrypoints.pooling.utils import enable_scoring_api from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.health import health @@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any -GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] +GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 2aaa83e75..c99c01441 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -65,7 +65,6 @@ class OpenAIServingRender: self, model_config: ModelConfig, renderer: BaseRenderer, - io_processor: Any, model_registry: OpenAIModelRegistry, *, request_logger: RequestLogger | None, @@ -81,7 +80,6 @@ class OpenAIServingRender: ) -> None: self.model_config = model_config self.renderer = renderer - self.io_processor = io_processor self.model_registry = model_registry self.request_logger = request_logger self.chat_template = chat_template diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py index c8cb4f185..c502f4f74 100644 --- a/vllm/plugins/io_processors/__init__.py 
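ServingScores now selects the processor class from the ScoringIOProcessors registry by score type and forwards the whole VllmConfig plus any extra arguments unchanged. A sketch of that selection; the registry keys shown are assumptions, not the actual vLLM score-type names:

```python
class ScoringIOProcessor:
    """Stand-in base; the real classes also take renderer/template kwargs."""

    def __init__(self, vllm_config, *args, **kwargs):
        self.vllm_config = vllm_config


class CrossEncoderIOProcessor(ScoringIOProcessor):
    pass


class LateInteractionIOProcessor(ScoringIOProcessor):
    pass


SCORING_IO_PROCESSORS = {
    "cross-encoder": CrossEncoderIOProcessor,
    "flash-late-interaction": LateInteractionIOProcessor,
}


def init_scoring_io_processor(vllm_config, score_type, enable_flash, *args, **kwargs):
    # Flash late interaction overrides whatever score type the model declared.
    if enable_flash:
        score_type = "flash-late-interaction"
    assert score_type in SCORING_IO_PROCESSORS
    processor_cls = SCORING_IO_PROCESSORS[score_type]
    # The whole config plus any extra args are forwarded unchanged.
    return processor_cls(vllm_config, *args, **kwargs)


proc = init_scoring_io_processor({"model": "demo"}, "cross-encoder", False)
print(type(proc).__name__)  # -> CrossEncoderIOProcessor
```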
+++ b/vllm/plugins/io_processors/__init__.py @@ -12,6 +12,23 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = logging.getLogger(__name__) +def has_io_processor( + vllm_config: VllmConfig, + plugin_from_init: str | None = None, +): + if plugin_from_init: + model_plugin = plugin_from_init + else: + # A plugin can be specified via the model config + # Retrieve the model specific plugin if available + # This is using a custom field in the hf_config for the model + hf_config = vllm_config.model_config.hf_config.to_dict() + config_plugin = hf_config.get("io_processor_plugin") + model_plugin = config_plugin + + return model_plugin is not None + + def get_io_processor( vllm_config: VllmConfig, renderer: BaseRenderer, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 7ff324f12..1c87d9ec0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,7 +26,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput -from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.renderers import renderer_from_config from vllm.renderers.inputs.preprocess import extract_prompt_components @@ -133,11 +132,6 @@ class AsyncLLM(EngineClient): ) self.renderer = renderer = renderer_from_config(self.vllm_config) - self.io_processor = get_io_processor( - self.vllm_config, - self.renderer, - self.model_config.io_processor_plugin, - ) # Convert EngineInput --> EngineCoreRequest. self.input_processor = InputProcessor(self.vllm_config, renderer) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4b6a7ba44..d0545651b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -19,7 +19,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.renderers import renderer_from_config from vllm.renderers.inputs.preprocess import extract_prompt_components @@ -90,11 +89,6 @@ class LLMEngine: self.should_execute_dummy_batch = False self.renderer = renderer = renderer_from_config(self.vllm_config) - self.io_processor = get_io_processor( - self.vllm_config, - self.renderer, - self.model_config.io_processor_plugin, - ) # Convert EngineInput --> EngineCoreRequest. self.input_processor = InputProcessor(self.vllm_config, renderer)
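has_io_processor mirrors the lookup that get_io_processor already performs: a plugin can be named explicitly or discovered from a custom io_processor_plugin field in the model's HF config. A stand-alone sketch of that check, with the config reduced to a plain dict instead of the real VllmConfig:

```python
from typing import Any


def has_io_processor(hf_config: dict[str, Any],
                     plugin_from_init: str | None = None) -> bool:
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # Fall back to the custom field carried in the model's HF config.
        model_plugin = hf_config.get("io_processor_plugin")
    return model_plugin is not None


print(has_io_processor({}, "my_plugin"))                        # -> True
print(has_io_processor({"io_processor_plugin": "geo_plugin"}))  # -> True
print(has_io_processor({}))                                     # -> False
```

With the plugin resolved this way inside the serving layer, AsyncLLM and LLMEngine no longer need to construct or hold an io_processor at all, which is why the attribute is removed from both engines and from the mocks in the tests.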