[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -123,7 +122,6 @@ async def test_chat_error_non_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
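
Across these chat-serving tests the setup is identical, minus the io_processor attribute this commit deletes. A minimal sketch of the shared pattern, assuming this test file's own helpers (MockModelConfig, _build_renderer, _build_serving_chat):

    from unittest.mock import MagicMock

    def _make_mock_engine():
        # Engine stub carrying only the attributes the serving layer still reads.
        mock_engine = MagicMock()
        mock_engine.errored = False
        mock_engine.model_config = MockModelConfig()
        mock_engine.input_processor = MagicMock()
        # io_processor is intentionally absent: OpenAIServingRender no longer takes it.
        mock_engine.renderer = _build_renderer(mock_engine.model_config)
        return mock_engine

    serving_chat = _build_serving_chat(_make_mock_engine())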
@@ -173,7 +171,6 @@ async def test_chat_error_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
@@ -567,7 +567,6 @@ def _build_serving_render(
     return OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=model_registry,
         request_logger=None,
         chat_template=CHAT_TEMPLATE,
@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
 class MockEngine:
     model_config: MockModelConfig = field(default_factory=MockModelConfig)
     input_processor: MagicMock = field(default_factory=MagicMock)
-    io_processor: MagicMock = field(default_factory=MagicMock)
     renderer: MagicMock = field(default_factory=MagicMock)
 
 
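MockEngine above uses dataclasses.field(default_factory=...) so every instance gets fresh mocks. A self-contained sketch of why that matters (toy Engine class, not project code):

    from dataclasses import dataclass, field
    from unittest.mock import MagicMock

    @dataclass
    class Engine:
        # default_factory runs per instance; a plain class-level default
        # would be created once and shared by every Engine.
        input_processor: MagicMock = field(default_factory=MagicMock)
        renderer: MagicMock = field(default_factory=MagicMock)

    a, b = Engine(), Engine()
    assert a.renderer is not b.renderer  # independent mocks per instance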
@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     # Initialize the serving chat
@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     # Initialize the serving chat
@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
 
     mock_tokenizer = MagicMock(spec=MistralTokenizer)
     mock_renderer = MistralRenderer(
@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
 
     mock_tokenizer = MagicMock(spec=MistralTokenizer)
     mock_renderer = MistralRenderer(
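
The Mistral tests build their tokenizer as MagicMock(spec=MistralTokenizer), constraining the mock to the real class's interface. A standalone illustration of spec= (the Tokenizer class here is a toy):

    from unittest.mock import MagicMock

    class Tokenizer:
        def encode(self, text: str) -> list[int]: ...

    tok = MagicMock(spec=Tokenizer)
    tok.encode("hi")        # allowed: encode is on the spec
    try:
        tok.decode([1])     # not on the spec
    except AttributeError:
        pass                # spec'd mocks reject unknown attributes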
@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     # Initialize the serving chat
@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_chat = _build_serving_chat(mock_engine)
@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     # Mock the generate method to return an async generator
@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony:
         mock_engine.errored = False
         mock_engine.model_config = MockModelConfig()
         mock_engine.input_processor = MagicMock()
-        mock_engine.io_processor = MagicMock()
         mock_engine.renderer = _build_renderer(mock_engine.model_config)
         return mock_engine
 
@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     models = OpenAIServingModels(
@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     models = OpenAIServingModels(

@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -107,7 +106,6 @@ async def test_completion_error_non_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_completion = _build_serving_completion(mock_engine)
@@ -157,7 +155,6 @@ async def test_completion_error_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     serving_completion = _build_serving_completion(mock_engine)

@@ -137,7 +137,6 @@ def mock_serving_setup():
 
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)
 
     models = OpenAIServingModels(
@@ -148,7 +147,6 @@ def mock_serving_setup():
     serving_render = OpenAIServingRender(
         model_config=mock_engine.model_config,
         renderer=mock_engine.renderer,
-        io_processor=mock_engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,

@@ -77,7 +77,6 @@ def _create_mock_engine():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
 
     # renderer is accessed by OpenAIServing.__init__ and serving.py
     mock_renderer = MagicMock()
@@ -218,7 +218,6 @@ class TestInitializeToolSessions:
         engine_client.model_config = model_config
 
         engine_client.input_processor = MagicMock()
-        engine_client.io_processor = MagicMock()
         engine_client.renderer = MagicMock()
 
         models = MagicMock()
@@ -307,7 +306,6 @@ class TestValidateGeneratorInput:
        engine_client.model_config = model_config
 
        engine_client.input_processor = MagicMock()
-       engine_client.io_processor = MagicMock()
        engine_client.renderer = MagicMock()
 
        models = MagicMock()
@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
     model_config.get_diff_sampling_param.return_value = {}
     engine_client.model_config = model_config
     engine_client.input_processor = MagicMock()
-    engine_client.io_processor = MagicMock()
     engine_client.renderer = MagicMock()
 
     tokenizer = FakeTokenizer()
@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning():
     model_config.get_diff_sampling_param.return_value = {}
     engine_client.model_config = model_config
     engine_client.input_processor = MagicMock()
-    engine_client.io_processor = MagicMock()
     engine_client.renderer = MagicMock()
 
     models = MagicMock()

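Several fixtures above pin model_config.get_diff_sampling_param.return_value = {} so downstream code receives a real dict rather than a nested MagicMock. The idiom in isolation (standalone sketch):

    from unittest.mock import MagicMock

    model_config = MagicMock()
    # Unpinned, the call would return another MagicMock, which breaks
    # dict-style consumers; pinning makes the stub behave like real config.
    model_config.get_diff_sampling_param.return_value = {}
    assert model_config.get_diff_sampling_param() == {}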
@@ -110,8 +110,11 @@ def test_score_api(llm: LLM):
     llm.score("ping", "pong", use_tqdm=False)
 
 
-@pytest.mark.parametrize("task", ["embed", "token_embed"])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask):
-    err_msg = "Embedding API is not supported by this model.+"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = "Embedding API is not supported by this model.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompt, pooling_task=task, use_tqdm=False)
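
The trailing .+ works because pytest.raises(match=...) applies re.search to the exception's string form, so the expected message is a regex rather than an exact string. A standalone illustration (the error text is made up):

    import pytest

    with pytest.raises(ValueError, match="not supported by this model.+"):
        raise ValueError("Embedding API is not supported by this model. Details follow.")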
@@ -469,4 +469,8 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
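
On the HTTP side no exception object exists, so the equivalent check is a plain str.startswith over the decoded JSON error body. A toy sketch with a fabricated payload:

    payload = {"error": {"type": "BadRequestError",
                         "message": "No IOProcessor plugin installed. See docs."}}
    assert payload["error"]["type"] == "BadRequestError"
    assert payload["error"]["message"].startswith("No IOProcessor plugin installed.")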
@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM):
     )
 
 
-@pytest.mark.parametrize("task", ["token_classify", "classify"])
+@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask):
-    err_msg = "Classification API is not supported by this model.+"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = "Classification API is not supported by this model.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -767,4 +767,8 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -65,14 +65,17 @@ def test_score_api(llm: LLM):
     llm.score("ping", "pong", use_tqdm=False)
 
 
-@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
+@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
     if task == "classify":
         with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
         assert "deprecated" in caplog_vllm.text
     else:
-        err_msg = "Embedding API is not supported by this model.+"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = "Embedding API is not supported by this model.+"
 
         with pytest.raises(ValueError, match=err_msg):
            llm.encode(prompt, pooling_task=task, use_tqdm=False)
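
caplog_vllm is the project's log-capture fixture; with stock pytest the deprecation check above reduces to the standard caplog pattern (the warning text here is illustrative):

    import logging

    def test_deprecation_is_logged(caplog):
        with caplog.at_level(logging.WARNING, logger="vllm"):
            logging.getLogger("vllm").warning("task 'classify' is deprecated")
        assert "deprecated" in caplog.text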
@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -64,7 +64,8 @@ async def test_pooling_not_supported(
         },
     )
 
-    assert response.json()["error"]["type"] == "BadRequestError"
-    if task == "plugin":
-        err_msg = "No IOProcessor plugin installed."
-    else:
-        err_msg = f"Unsupported task: {task!r}"
-    assert response.json()["error"]["message"].startswith(err_msg)
+    if task != "classify":
+        assert response.json()["error"]["type"] == "BadRequestError"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = f"Unsupported task: {task!r}"
+        assert response.json()["error"]["message"].startswith(err_msg)
@@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM):
     assert outputs[0].outputs.data.shape == (11, 384)
 
 
-@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
+@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
     if task == "embed":
         with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
         assert "deprecated" in caplog_vllm.text
     else:
-        err_msg = "Classification API is not supported by this model.+"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = "Classification API is not supported by this model.+"
 
         with pytest.raises(ValueError, match=err_msg):
            llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
+@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -87,7 +87,8 @@ async def test_pooling_not_supported(
         },
     )
 
-    assert response.json()["error"]["type"] == "BadRequestError"
-    if task == "plugin":
-        err_msg = "No IOProcessor plugin installed."
-    else:
-        err_msg = f"Unsupported task: {task!r}"
-    assert response.json()["error"]["message"].startswith(err_msg)
+    if task != "embed":
+        assert response.json()["error"]["type"] == "BadRequestError"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = f"Unsupported task: {task!r}"
+        assert response.json()["error"]["message"].startswith(err_msg)
@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock:
     engine.errored = False
     engine.model_config = MockModelConfig()
     engine.input_processor = MagicMock()
-    engine.io_processor = MagicMock()
     engine.renderer = _build_renderer(engine.model_config)
     return engine

@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config.max_model_len = 2048
     mock_engine_client.model_config = mock_model_config
     mock_engine_client.input_processor = MagicMock()
-    mock_engine_client.io_processor = MagicMock()
     mock_engine_client.renderer = MagicMock()
 
     serving_models = OpenAIServingModels(

@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument():
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,