[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-04-09 18:09:45 +08:00
committed by GitHub
parent b6c9be509e
commit 66c079ae83
43 changed files with 554 additions and 733 deletions

View File

@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -123,7 +122,6 @@ async def test_chat_error_non_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -173,7 +171,6 @@ async def test_chat_error_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)

View File

@@ -567,7 +567,6 @@ def _build_serving_render(
return OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=model_registry,
request_logger=None,
chat_template=CHAT_TEMPLATE,
@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig)
input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
renderer: MagicMock = field(default_factory=MagicMock)
@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(
@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Mock the generate method to return an async generator
@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony:
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine
@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(

View File

@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -107,7 +106,6 @@ async def test_completion_error_non_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
@@ -157,7 +155,6 @@ async def test_completion_error_stream():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)

View File

@@ -137,7 +137,6 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
@@ -148,7 +147,6 @@ def mock_serving_setup():
serving_render = OpenAIServingRender(
model_config=mock_engine.model_config,
renderer=mock_engine.renderer,
io_processor=mock_engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,

View File

@@ -77,7 +77,6 @@ def _create_mock_engine():
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
# renderer is accessed by OpenAIServing.__init__ and serving.py
mock_renderer = MagicMock()

View File

@@ -218,7 +218,6 @@ class TestInitializeToolSessions:
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
@@ -307,7 +306,6 @@ class TestValidateGeneratorInput:
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
tokenizer = FakeTokenizer()
@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning():
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()

View File

@@ -110,8 +110,11 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["embed", "token_embed"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = "Embedding API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -469,4 +469,8 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM):
)
@pytest.mark.parametrize("task", ["token_classify", "classify"])
@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = "Classification API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -767,4 +767,8 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -65,14 +65,17 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "classify":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
err_msg = "Embedding API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
@@ -64,7 +64,8 @@ async def test_pooling_not_supported(
},
)
if task != "classify":
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM):
assert outputs[0].outputs.data.shape == (11, 384)
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "embed":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
err_msg = "Classification API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)

View File

@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
@@ -87,7 +87,8 @@ async def test_pooling_not_supported(
},
)
if task != "embed":
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
assert response.json()["error"]["message"].startswith(err_msg)

View File

@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock:
engine.errored = False
engine.model_config = MockModelConfig()
engine.input_processor = MagicMock()
engine.io_processor = MagicMock()
engine.renderer = _build_renderer(engine.model_config)
return engine

View File

@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config.max_model_len = 2048
mock_engine_client.model_config = mock_model_config
mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
mock_engine_client.renderer = MagicMock()
serving_models = OpenAIServingModels(

View File

@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument():
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,