[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
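This series replaces the engine-level `io_processor` with per-task pooling IO processors. A minimal sketch of the offline entrypoint after the change, assuming an arbitrary pooling model (the model name is illustrative; the `encode`/`pooling_task` call shape follows the tests below):

    from vllm import LLM

    llm = LLM(model="BAAI/bge-small-en-v1.5")  # hypothetical pooling model
    # pooling_task selects one dedicated PoolingIOProcessor; unsupported
    # tasks raise ValueError (e.g. "No IOProcessor plugin installed." when
    # pooling_task="plugin" and no plugin is loaded).
    outputs = llm.encode("ping", pooling_task="embed", use_tqdm=False)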
@@ -87,7 +87,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -123,7 +122,6 @@ async def test_chat_error_non_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -173,7 +171,6 @@ async def test_chat_error_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -567,7 +567,6 @@ def _build_serving_render(
     return OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=model_registry,
         request_logger=None,
         chat_template=CHAT_TEMPLATE,
@@ -599,7 +598,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
 class MockEngine:
     model_config: MockModelConfig = field(default_factory=MockModelConfig)
     input_processor: MagicMock = field(default_factory=MagicMock)
-    io_processor: MagicMock = field(default_factory=MagicMock)
     renderer: MagicMock = field(default_factory=MagicMock)


@@ -632,7 +630,6 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -662,7 +659,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -693,7 +689,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
@@ -737,7 +732,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -779,7 +773,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
@@ -823,7 +816,6 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()

     mock_tokenizer = MagicMock(spec=MistralTokenizer)
     mock_renderer = MistralRenderer(
@@ -863,7 +855,6 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()

     mock_tokenizer = MagicMock(spec=MistralTokenizer)
     mock_renderer = MistralRenderer(
@@ -906,7 +897,6 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
@@ -952,7 +942,6 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
@@ -1003,7 +992,6 @@ async def test_serving_chat_data_parallel_rank_extraction():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Mock the generate method to return an async generator
@@ -1095,7 +1083,6 @@ class TestServingChatWithHarmony:
         mock_engine.errored = False
         mock_engine.model_config = MockModelConfig()
         mock_engine.input_processor = MagicMock()
-        mock_engine.io_processor = MagicMock()
         mock_engine.renderer = _build_renderer(mock_engine.model_config)
         return mock_engine

@@ -1732,7 +1719,6 @@ async def test_tool_choice_validation_without_parser():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     models = OpenAIServingModels(
@@ -1802,7 +1788,6 @@ async def test_streaming_n_gt1_independent_tool_parsers():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     models = OpenAIServingModels(
@@ -79,7 +79,6 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -107,7 +106,6 @@ async def test_completion_error_non_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_completion = _build_serving_completion(mock_engine)
@@ -157,7 +155,6 @@ async def test_completion_error_stream():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_completion = _build_serving_completion(mock_engine)
@@ -137,7 +137,6 @@ def mock_serving_setup():

     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()
     mock_engine.renderer = _build_renderer(mock_engine.model_config)

     models = OpenAIServingModels(
@@ -148,7 +147,6 @@ def mock_serving_setup():
     serving_render = OpenAIServingRender(
         model_config=mock_engine.model_config,
         renderer=mock_engine.renderer,
-        io_processor=mock_engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -77,7 +77,6 @@ def _create_mock_engine():
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
-    mock_engine.io_processor = MagicMock()

     # renderer is accessed by OpenAIServing.__init__ and serving.py
     mock_renderer = MagicMock()
@@ -218,7 +218,6 @@ class TestInitializeToolSessions:
         engine_client.model_config = model_config

         engine_client.input_processor = MagicMock()
-        engine_client.io_processor = MagicMock()
         engine_client.renderer = MagicMock()

         models = MagicMock()
@@ -307,7 +306,6 @@ class TestValidateGeneratorInput:
         engine_client.model_config = model_config

         engine_client.input_processor = MagicMock()
-        engine_client.io_processor = MagicMock()
         engine_client.renderer = MagicMock()

         models = MagicMock()
@@ -369,7 +367,6 @@ async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
     model_config.get_diff_sampling_param.return_value = {}
     engine_client.model_config = model_config
     engine_client.input_processor = MagicMock()
-    engine_client.io_processor = MagicMock()
     engine_client.renderer = MagicMock()

     tokenizer = FakeTokenizer()
@@ -672,7 +669,6 @@ def _make_serving_instance_with_reasoning():
     model_config.get_diff_sampling_param.return_value = {}
     engine_client.model_config = model_config
     engine_client.input_processor = MagicMock()
-    engine_client.io_processor = MagicMock()
     engine_client.renderer = MagicMock()

     models = MagicMock()
@@ -110,8 +110,11 @@ def test_score_api(llm: LLM):
     llm.score("ping", "pong", use_tqdm=False)


-@pytest.mark.parametrize("task", ["embed", "token_embed"])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask):
-    err_msg = "Embedding API is not supported by this model.+"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = "Embedding API is not supported by this model.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -469,4 +469,8 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -107,8 +107,11 @@ def test_pooling_params(llm: LLM):
     )


-@pytest.mark.parametrize("task", ["token_classify", "classify"])
+@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask):
-    err_msg = "Classification API is not supported by this model.+"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = "Classification API is not supported by this model.+"
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -767,4 +767,8 @@ async def test_pooling_not_supported(
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -411,4 +411,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -484,4 +484,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -65,14 +65,17 @@ def test_score_api(llm: LLM):
     llm.score("ping", "pong", use_tqdm=False)


-@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
+@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
     if task == "classify":
         with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
         assert "deprecated" in caplog_vllm.text
     else:
-        err_msg = "Embedding API is not supported by this model.+"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = "Embedding API is not supported by this model.+"

         with pytest.raises(ValueError, match=err_msg):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
+@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -64,7 +64,8 @@ async def test_pooling_not_supported(
         },
     )

-    if task != "classify":
-        assert response.json()["error"]["type"] == "BadRequestError"
-        assert response.json()["error"]["message"].startswith(err_msg)
+    assert response.json()["error"]["type"] == "BadRequestError"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -62,14 +62,17 @@ def test_token_ids_prompts(llm: LLM):
     assert outputs[0].outputs.data.shape == (11, 384)


-@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
+@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
     if task == "embed":
         with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
         assert "deprecated" in caplog_vllm.text
     else:
-        err_msg = "Classification API is not supported by this model.+"
+        if task == "plugin":
+            err_msg = "No IOProcessor plugin installed."
+        else:
+            err_msg = "Classification API is not supported by this model.+"

         with pytest.raises(ValueError, match=err_msg):
             llm.encode(prompt, pooling_task=task, use_tqdm=False)
@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
+@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -87,7 +87,8 @@ async def test_pooling_not_supported(
         },
     )

-    if task != "embed":
-        assert response.json()["error"]["type"] == "BadRequestError"
-        assert response.json()["error"]["message"].startswith(err_msg)
+    assert response.json()["error"]["type"] == "BadRequestError"
+    if task == "plugin":
+        err_msg = "No IOProcessor plugin installed."
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
@@ -86,7 +86,6 @@ def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -148,7 +147,6 @@ def _mock_engine() -> MagicMock:
     engine.errored = False
     engine.model_config = MockModelConfig()
     engine.input_processor = MagicMock()
-    engine.io_processor = MagicMock()
     engine.renderer = _build_renderer(engine.model_config)
     return engine
@@ -34,7 +34,6 @@ async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config.max_model_len = 2048
     mock_engine_client.model_config = mock_model_config
     mock_engine_client.input_processor = MagicMock()
-    mock_engine_client.io_processor = MagicMock()
     mock_engine_client.renderer = MagicMock()

     serving_models = OpenAIServingModels(
@@ -514,7 +514,6 @@ async def test_header_dp_rank_argument():
     serving_render = OpenAIServingRender(
         model_config=engine.model_config,
         renderer=engine.renderer,
-        io_processor=engine.io_processor,
         model_registry=models.registry,
         request_logger=None,
         chat_template=None,
@@ -14,7 +14,6 @@ from vllm.distributed.weight_transfer.base import (
 from vllm.inputs import EngineInput, PromptType
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, RequestOutput
-from vllm.plugins.io_processors import IOProcessor
 from vllm.pooling_params import PoolingParams
 from vllm.renderers import BaseRenderer
 from vllm.sampling_params import SamplingParams
@@ -44,7 +43,6 @@ class EngineClient(ABC):
     vllm_config: VllmConfig
     model_config: ModelConfig
     renderer: BaseRenderer
-    io_processor: IOProcessor | None
     input_processor: InputProcessor

     @property
@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import (
     load_chat_template,
 )
+from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
-from vllm.entrypoints.pooling.scoring.io_processor import (
-    ScoringIOProcessor,
-)
+from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessor
 from vllm.entrypoints.pooling.scoring.typing import ScoreInput
 from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
 from vllm.entrypoints.utils import log_non_default_args
@@ -398,12 +396,11 @@ class LLM:
         self.runner_type = self.model_config.runner_type
         self.renderer = self.llm_engine.renderer
         self.chat_template = load_chat_template(chat_template)
-        self.io_processor = self.llm_engine.io_processor
         self.input_processor = self.llm_engine.input_processor
         self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
         self.pooling_io_processors = init_pooling_io_processors(
             supported_tasks=supported_tasks,
-            model_config=self.model_config,
+            vllm_config=self.llm_engine.vllm_config,
             renderer=self.renderer,
             chat_template_config=self.chat_template_config,
         )
@@ -1081,118 +1078,55 @@ class LLM:
         pooled hidden states in the same order as the input prompts.
         """

-        self._verify_pooling_task(pooling_task)
-
-        if isinstance(prompts, dict) and "data" in prompts:
-            if self.io_processor is None:
-                raise ValueError(
-                    "No IOProcessor plugin installed. Please refer "
-                    "to the documentation and to the "
-                    "'prithvi_geospatial_mae_io_processor' "
-                    "offline inference example for more details."
-                )
-
-            # Validate the request data is valid for the loaded plugin
-            prompt_data = prompts.get("data")
-            if prompt_data is None:
-                raise ValueError(
-                    "The 'data' field of the prompt is expected to contain "
-                    "the prompt data and it cannot be None. "
-                    "Refer to the documentation of the IOProcessor "
-                    "in use for more details."
-                )
-            validated_prompt = self.io_processor.parse_data(prompt_data)
-
-            # obtain the actual model prompts from the pre-processor
-            prompts = self.io_processor.pre_process(prompt=validated_prompt)
-            prompts_seq = prompt_to_seq(prompts)
-
-            params_seq: Sequence[PoolingParams] = [
-                self.io_processor.merge_pooling_params(param)
-                for param in self._params_to_seq(
-                    pooling_params,
-                    len(prompts_seq),
-                )
-            ]
-            for p in params_seq:
-                if p.task is None:
-                    p.task = "plugin"
-
-            outputs = self._run_completion(
-                prompts=prompts_seq,
-                params=params_seq,
-                output_type=PoolingRequestOutput,
-                use_tqdm=use_tqdm,
-                lora_request=lora_request,
-                tokenization_kwargs=tokenization_kwargs,
+        if isinstance(prompts, dict) and "data" in prompts and pooling_task != "plugin":
+            raise ValueError(
+                "The 'data' field is only supported for the 'plugin' pooling task."
             )
+        self._verify_pooling_task(pooling_task)
+        assert pooling_task is not None and pooling_task in self.pooling_io_processors

-            # get the post-processed model outputs
-            assert self.io_processor is not None
-            processed_outputs = self.io_processor.post_process(outputs)
+        io_processor = self.pooling_io_processors[pooling_task]

-            return [
-                PoolingRequestOutput[Any](
-                    request_id="",
-                    outputs=processed_outputs,
-                    num_cached_tokens=getattr(
-                        processed_outputs, "num_cached_tokens", 0
-                    ),
-                    prompt_token_ids=[],
-                    finished=True,
-                )
-            ]
-        else:
-            if pooling_params is None:
-                # Use default pooling params.
-                pooling_params = PoolingParams()
+        if pooling_params is None:
+            pooling_params = PoolingParams()

-            prompts_seq = prompt_to_seq(prompts)
-            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))
+        ctx = OfflineInputsContext(
+            prompts=prompts,
+            pooling_params=pooling_params,
+            tokenization_kwargs=tokenization_kwargs,
+        )

-            for param in params_seq:
-                if param.task is None:
-                    param.task = pooling_task
-                elif param.task != pooling_task:
-                    msg = (
-                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
-                    )
-                    raise ValueError(msg)
+        engine_inputs = io_processor.pre_process_offline(ctx)
+        n_inputs = len(engine_inputs)
+        assert ctx.pooling_params is not None

-            if pooling_task in self.pooling_io_processors:
-                io_processor = self.pooling_io_processors[pooling_task]
-                processor_inputs = io_processor.pre_process_offline(
-                    ctx=OfflineInputsContext(
-                        prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs
-                    )
-                )
-                seq_lora_requests = self._lora_request_to_seq(
-                    lora_request, len(prompts_seq)
-                )
-                seq_priority = self._priority_to_seq(None, len(prompts))
+        params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)

-                self._render_and_add_requests(
-                    prompts=processor_inputs,
-                    params=params_seq,
-                    lora_requests=seq_lora_requests,
-                    priorities=seq_priority,
-                )
+        for param in params_seq:
+            if param.task is None:
+                param.task = pooling_task
+            elif pooling_task == "plugin":
+                # `plugin` task uses io_processor.parse_request to verify inputs.
+                # We actually allow plugin to overwrite pooling_task.
+                pass
+            elif param.task != pooling_task:
+                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
+                raise ValueError(msg)

-                outputs = self._run_engine(
-                    use_tqdm=use_tqdm, output_type=PoolingRequestOutput
-                )
-                outputs = io_processor.post_process_offline(
-                    ctx=OfflineOutputsContext(outputs=outputs)
-                )
-            else:
-                outputs = self._run_completion(
-                    prompts=prompts_seq,
-                    params=params_seq,
-                    output_type=PoolingRequestOutput,
-                    use_tqdm=use_tqdm,
-                    lora_request=lora_request,
-                    tokenization_kwargs=tokenization_kwargs,
-                )
+        seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
+        seq_priority = self._priority_to_seq(None, n_inputs)
+
+        self._render_and_add_requests(
+            prompts=engine_inputs,
+            params=params_seq,
+            lora_requests=seq_lora_requests,
+            priorities=seq_priority,
+        )
+
+        outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
+        outputs = io_processor.post_process_offline(
+            ctx=OfflineOutputsContext(outputs=outputs)
+        )
         return outputs

     def _verify_pooling_task(self, pooling_task: PoolingTask | None):
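The rewrite above makes the plugin path explicit instead of inferring it from the prompt shape. A hedged sketch of the new calling convention (the payload contents are an assumption — they depend on the installed IOProcessor plugin):

    # Previously, a {"data": ...} prompt was routed to the plugin implicitly;
    # now the caller must name the task, otherwise encode() raises
    # "The 'data' field is only supported for the 'plugin' pooling task."
    outputs = llm.encode(
        {"data": plugin_payload},  # plugin-specific payload (hypothetical)
        pooling_task="plugin",
    )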
@@ -1254,6 +1188,14 @@ class LLM:
                 pooling_task,
             )

+        if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
+            raise ValueError(
+                "No IOProcessor plugin installed. Please refer "
+                "to the documentation and to the "
+                "'prithvi_geospatial_mae_io_processor' "
+                "offline inference example for more details."
+            )
+
     def embed(
         self,
         prompts: PromptType | Sequence[PromptType],
@@ -1458,6 +1400,9 @@ class LLM:
         scoring_data = io_processor.valid_inputs(data_1, data_2)
         n_queries = len(scoring_data.data_1)

+        if pooling_params is None:
+            pooling_params = PoolingParams()
+
         ctx = OfflineInputsContext(
             prompts=scoring_data,
             pooling_params=pooling_params,
@@ -1466,15 +1411,11 @@ class LLM:
             n_queries=n_queries,
         )

-        processor_inputs = io_processor.pre_process_offline(ctx)
+        engine_inputs = io_processor.pre_process_offline(ctx)
+        n_inputs = len(engine_inputs)

-        seq_lora_requests = self._lora_request_to_seq(
-            lora_request, len(processor_inputs)
-        )
-
-        if ctx.pooling_params is None:
-            ctx.pooling_params = PoolingParams()
-        params_seq = self._params_to_seq(ctx.pooling_params, len(processor_inputs))
+        seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
+        params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)

         for param in params_seq:
             if param.task is None:
@@ -1483,10 +1424,10 @@ class LLM:
                 msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                 raise ValueError(msg)

-        seq_priority = self._priority_to_seq(None, len(processor_inputs))
+        seq_priority = self._priority_to_seq(None, n_inputs)

         self._render_and_add_requests(
-            prompts=processor_inputs,
+            prompts=engine_inputs,
             params=params_seq,
             lora_requests=seq_lora_requests,
             priorities=seq_priority,
@@ -1579,7 +1520,7 @@ class LLM:
         if isinstance(params, Sequence):
             if len(params) != num_requests:
                 raise ValueError(
-                    f"The lengths of prompts ({params}) "
+                    f"The lengths of prompts ({num_requests}) "
                     f"and params ({len(params)}) must be the same."
                 )
@@ -370,7 +370,6 @@ async def init_app_state(
     state.openai_serving_render = OpenAIServingRender(
         model_config=engine_client.model_config,
         renderer=engine_client.renderer,
-        io_processor=engine_client.io_processor,
         model_registry=state.openai_serving_models.registry,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
@@ -441,13 +440,12 @@ async def init_render_app_state(

     Unlike :func:`init_app_state` this function does not require an
     :class:`~vllm.engine.protocol.EngineClient`; it bootstraps the
-    preprocessing pipeline (renderer, io_processor, input_processor)
+    preprocessing pipeline (renderer, input_processor)
     directly from the :class:`~vllm.config.VllmConfig`.
     """
     from vllm.entrypoints.chat_utils import load_chat_template
     from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
     from vllm.entrypoints.serve.render.serving import OpenAIServingRender
-    from vllm.plugins.io_processors import get_io_processor
     from vllm.renderers import renderer_from_config

     served_model_names = args.served_model_name or [args.model]
@@ -465,15 +463,11 @@ async def init_render_app_state(
     request_logger = None

     renderer = renderer_from_config(vllm_config)
-    io_processor = get_io_processor(
-        vllm_config, renderer, vllm_config.model_config.io_processor_plugin
-    )
     resolved_chat_template = load_chat_template(args.chat_template)

     state.openai_serving_render = OpenAIServingRender(
         model_config=vllm_config.model_config,
         renderer=renderer,
-        io_processor=io_processor,
         model_registry=model_registry,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
     TranscriptionResponse,
     TranslationRequest,
 )
-from vllm.entrypoints.pooling.pooling.protocol import (
-    IOProcessorRequest,
-    PoolingChatRequest,
-    PoolingCompletionRequest,
-    PoolingResponse,
-)
 from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.serve.tokenize.protocol import (
     DetokenizeRequest,
@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
-from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
-from vllm.pooling_params import PoolingParams
+from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.renderers import ChatParams, TokenizeParams
 from vllm.renderers.inputs.preprocess import (
     extract_prompt_components,
@@ -78,10 +71,7 @@ from vllm.tracing import (
     log_tracing_disabled_warning,
 )
 from vllm.utils import random_uuid
-from vllm.utils.async_utils import (
-    collect_from_async_generator,
-    merge_async_iterators,
-)
+from vllm.utils.async_utils import collect_from_async_generator

 logger = init_logger(__name__)

@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol):


 CompletionLikeRequest: TypeAlias = (
-    CompletionRequest
-    | TokenizeCompletionRequest
-    | DetokenizeRequest
-    | PoolingCompletionRequest
+    CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
 )

 ChatLikeRequest: TypeAlias = (
-    ChatCompletionRequest
-    | BatchChatCompletionRequest
-    | TokenizeChatRequest
-    | PoolingChatRequest
+    ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest
 )

 SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = (
     | ChatLikeRequest
     | SpeechToTextRequest
     | ResponsesRequest
-    | IOProcessorRequest
     | GenerateRequest
 )

@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = (
     | ChatCompletionResponse
     | TranscriptionResponse
     | TokenizeResponse
-    | PoolingResponse
     | GenerateResponse
 )

@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]):
     created_time: int = field(default_factory=lambda: int(time.time()))
     lora_request: LoRARequest | None = None
     engine_inputs: list[EngineInput] | None = None
-
-    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
-        None
-    )
-    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -171,7 +147,6 @@ class OpenAIServing:
         super().__init__()

         self.engine_client = engine_client
-
         self.models = models

         self.request_logger = request_logger
@@ -179,7 +154,6 @@ class OpenAIServing:

         self.model_config = engine_client.model_config
         self.renderer = engine_client.renderer
-        self.io_processor = engine_client.io_processor
         self.input_processor = engine_client.input_processor

     async def beam_search(
@@ -381,155 +355,6 @@ class OpenAIServing:
             prompt_logprobs=None,
         )

-    async def _preprocess(
-        self,
-        ctx: ServeContext,
-    ) -> ErrorResponse | None:
-        """
-        Default preprocessing hook. Subclasses may override to prepare `ctx`.
-        """
-        return None
-
-    def _build_response(
-        self,
-        ctx: ServeContext,
-    ) -> AnyResponse | ErrorResponse:
-        """
-        Default response builder. Subclass may override this method
-        to return the appropriate response object.
-        """
-        return self.create_error_response("unimplemented endpoint")
-
-    async def handle(
-        self,
-        ctx: ServeContext,
-    ) -> AnyResponse | ErrorResponse:
-        async for response in self._pipeline(ctx):
-            return response
-
-        return self.create_error_response("No response yielded from pipeline")
-
-    async def _pipeline(
-        self,
-        ctx: ServeContext,
-    ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
-        """Execute the request processing pipeline yielding responses."""
-        if error := await self._check_model(ctx.request):
-            yield error
-        if error := self._validate_request(ctx):
-            yield error
-
-        preprocess_ret = await self._preprocess(ctx)
-        if isinstance(preprocess_ret, ErrorResponse):
-            yield preprocess_ret
-
-        generators_ret = await self._prepare_generators(ctx)
-        if isinstance(generators_ret, ErrorResponse):
-            yield generators_ret
-
-        collect_ret = await self._collect_batch(ctx)
-        if isinstance(collect_ret, ErrorResponse):
-            yield collect_ret
-
-        yield self._build_response(ctx)
-
-    def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
-        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
-
-        if (
-            truncate_prompt_tokens is not None
-            and truncate_prompt_tokens > self.model_config.max_model_len
-        ):
-            return self.create_error_response(
-                "truncate_prompt_tokens value is "
-                "greater than max_model_len."
-                " Please request a smaller truncation size."
-            )
-        return None
-
-    def _create_pooling_params(
-        self,
-        ctx: ServeContext,
-    ) -> PoolingParams | ErrorResponse:
-        if not hasattr(ctx.request, "to_pooling_params"):
-            return self.create_error_response(
-                "Request type does not support pooling parameters"
-            )
-
-        return ctx.request.to_pooling_params()
-
-    async def _prepare_generators(
-        self,
-        ctx: ServeContext,
-    ) -> ErrorResponse | None:
-        """Schedule the request and get the result generator."""
-        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
-
-        trace_headers = (
-            None
-            if ctx.raw_request is None
-            else await self._get_trace_headers(ctx.raw_request.headers)
-        )
-
-        pooling_params = self._create_pooling_params(ctx)
-        if isinstance(pooling_params, ErrorResponse):
-            return pooling_params
-
-        if ctx.engine_inputs is None:
-            return self.create_error_response("Engine prompts not available")
-
-        for i, engine_input in enumerate(ctx.engine_inputs):
-            request_id_item = f"{ctx.request_id}-{i}"
-
-            self._log_inputs(
-                request_id_item,
-                engine_input,
-                params=pooling_params,
-                lora_request=ctx.lora_request,
-            )
-
-            generator = self.engine_client.encode(
-                engine_input,
-                pooling_params,
-                request_id_item,
-                lora_request=ctx.lora_request,
-                trace_headers=trace_headers,
-                priority=getattr(ctx.request, "priority", 0),
-            )
-
-            generators.append(generator)
-
-        ctx.result_generator = merge_async_iterators(*generators)
-
-        return None
-
-    async def _collect_batch(
-        self,
-        ctx: ServeContext,
-    ) -> ErrorResponse | None:
-        """Collect batch results from the result generator."""
-        if ctx.engine_inputs is None:
-            return self.create_error_response("Engine prompts not available")
-
-        num_prompts = len(ctx.engine_inputs)
-        final_res_batch: list[PoolingRequestOutput | None]
-        final_res_batch = [None] * num_prompts
-
-        if ctx.result_generator is None:
-            return self.create_error_response("Result generator not available")
-
-        async for i, res in ctx.result_generator:
-            final_res_batch[i] = res
-
-        if None in final_res_batch:
-            return self.create_error_response(
-                "Failed to generate results for all prompts"
-            )
-
-        ctx.final_res_batch = [res for res in final_res_batch if res is not None]
-
-        return None
-
     @staticmethod
     def create_error_response(
         message: str | Exception,
@@ -719,7 +544,7 @@ class OpenAIServing:
         self,
         request_id: str,
         inputs: PromptType | EngineInput,
-        params: SamplingParams | PoolingParams | BeamSearchParams | None,
+        params: SamplingParams | BeamSearchParams | None,
         lora_request: LoRARequest | None,
     ) -> None:
         if self.request_logger is None:
@@ -112,7 +112,6 @@ class OpenAIServingModels:

         self.model_config = self.engine_client.model_config
         self.renderer = self.engine_client.renderer
-        self.io_processor = self.engine_client.io_processor
         self.input_processor = self.engine_client.input_processor

     async def init_static_loras(self):
@@ -67,20 +67,18 @@ def init_pooling_state(
     from vllm.entrypoints.chat_utils import load_chat_template
     from vllm.entrypoints.pooling.classify.serving import ServingClassification
     from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
-    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
+    from vllm.entrypoints.pooling.pooling.serving import ServingPooling
     from vllm.entrypoints.pooling.scoring.serving import ServingScores
     from vllm.tasks import POOLING_TASKS

     model_config = engine_client.model_config

     resolved_chat_template = load_chat_template(args.chat_template)

     state.serving_pooling = (
         (
-            OpenAIServingPooling(
+            ServingPooling(
                 engine_client,
                 state.openai_serving_models,
                 state.openai_serving_render,
                 supported_tasks=supported_tasks,
                 request_logger=request_logger,
                 chat_template=resolved_chat_template,
@@ -4,8 +4,8 @@
 from collections.abc import Sequence
 from typing import Any, Final

-from vllm import PoolingRequestOutput, PromptType
-from vllm.config import ModelConfig
+from vllm import PoolingParams, PoolingRequestOutput, PromptType
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateConfig,
@@ -33,11 +33,12 @@ class PoolingIOProcessor:

     def __init__(
         self,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         renderer: BaseRenderer,
         chat_template_config: ChatTemplateConfig,
     ):
-        self.model_config = model_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
         self.renderer = renderer

         self.chat_template = chat_template_config.chat_template
@@ -48,12 +49,12 @@ class PoolingIOProcessor:
             chat_template_config.trust_request_chat_template
         )

-    def create_pooling_params(self, request):
-        return request.to_pooling_params()
-
     #######################################
     # online APIs

+    def create_pooling_params(self, request):
+        return request.to_pooling_params()
+
     def pre_process_online(self, ctx: PoolingServeContext):
         request = ctx.request

@@ -100,12 +101,16 @@ class PoolingIOProcessor:
     # offline APIs

     def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
-        assert not isinstance(ctx.prompts, ScoringData)
+        assert not isinstance(ctx.prompts, ScoringData) and not (
+            isinstance(ctx.prompts, dict) and "data" in ctx.prompts
+        )
+
+        prompts_seq = prompt_to_seq(ctx.prompts)
         tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
             **(ctx.tokenization_kwargs or {})
         )
         return self._preprocess_completion_offline(
-            prompts=ctx.prompts, tok_params=tok_params
+            prompts=prompts_seq, tok_params=tok_params
        )

     async def pre_process_offline_async(self, ctx: OfflineInputsContext):
@@ -243,3 +248,19 @@ class PoolingIOProcessor:
                 "Refused request with untrusted chat template."
             )
         return None
+
+    def _params_to_seq(
+        self,
+        params: PoolingParams | Sequence[PoolingParams],
+        num_requests: int,
+    ) -> Sequence[PoolingParams]:
+        if isinstance(params, Sequence):
+            if len(params) != num_requests:
+                raise ValueError(
+                    f"The lengths of prompts ({num_requests}) "
+                    f"and params ({len(params)}) must be the same."
+                )
+
+            return params
+
+        return [params] * num_requests
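A small illustration of the broadcast semantics `_params_to_seq` adds; `io_processor` is assumed to be an already-constructed `PoolingIOProcessor`, and the request count of 3 is arbitrary:

    from vllm import PoolingParams

    params = PoolingParams()
    seq = io_processor._params_to_seq(params, 3)
    assert len(seq) == 3 and all(p is params for p in seq)  # one object, fanned out
    # An explicit sequence must match the request count exactly;
    # io_processor._params_to_seq([params, params], 3) raises ValueError.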
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Mapping
 from http import HTTPStatus
 from typing import ClassVar
@@ -9,7 +11,7 @@ from fastapi.responses import Response
 from starlette.datastructures import Headers

 from vllm import PoolingParams, PoolingRequestOutput, envs
-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
     ChatTemplateConfig,
@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators
 from .io_processor import PoolingIOProcessor


-class PoolingServing:
+class PoolingServingBase(ABC):
     request_id_prefix: ClassVar[str]

     def __init__(
@@ -50,10 +52,11 @@
         return_tokens_as_token_ids: bool = False,
         log_error_stack: bool = False,
     ):
         super().__init__()
         self.engine_client = engine_client
         self.models = models
         self.model_config = models.model_config
         self.renderer = models.renderer
+        self.vllm_config = engine_client.vllm_config
         self.max_model_len = self.model_config.max_model_len
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
@@ -63,31 +66,14 @@ class PoolingServing:
             chat_template_content_format=chat_template_content_format,
             trust_request_chat_template=trust_request_chat_template,
         )
-        self.io_processor = self.init_io_processor(
-            model_config=models.model_config,
-            renderer=models.renderer,
-            chat_template_config=self.chat_template_config,
-        )
-
-    def init_io_processor(
-        self,
-        model_config: ModelConfig,
-        renderer: BaseRenderer,
-        chat_template_config: ChatTemplateConfig,
-    ) -> PoolingIOProcessor:
-        raise NotImplementedError

+    @abstractmethod
     async def __call__(
         self,
         request: AnyPoolingRequest,
         raw_request: Request | None = None,
     ) -> Response:
-        ctx = await self._init_ctx(request, raw_request)
-        await self.io_processor.pre_process_online_async(ctx)
-        await self._prepare_generators(ctx)
-        await self._collect_batch(ctx)
-        await self.io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+        raise NotImplementedError

     async def _init_ctx(
         self,
@@ -124,10 +110,8 @@ class PoolingServing:
             else await self._get_trace_headers(ctx.raw_request.headers)
         )

-        if ctx.pooling_params is None:
-            pooling_params = self.io_processor.create_pooling_params(ctx.request)
-        else:
-            pooling_params = ctx.pooling_params
+        assert ctx.pooling_params is not None
+        pooling_params = ctx.pooling_params

         if isinstance(pooling_params, list):
             for params in pooling_params:
@@ -190,6 +174,7 @@ class PoolingServing:

         ctx.final_res_batch = [res for res in final_res_batch if res is not None]

+    @abstractmethod
     async def _build_response(
         self,
         ctx: PoolingServeContext,
@@ -355,3 +340,39 @@ class PoolingServing:
             params=params,
             lora_request=lora_request,
         )
+
+
+class PoolingServing(PoolingServingBase, ABC):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.io_processor = self.init_io_processor(
+            vllm_config=self.vllm_config,
+            renderer=self.renderer,
+            chat_template_config=self.chat_template_config,
+        )
+
+    @abstractmethod
+    def init_io_processor(
+        self,
+        vllm_config: VllmConfig,
+        renderer: BaseRenderer,
+        chat_template_config: ChatTemplateConfig,
+    ) -> PoolingIOProcessor:
+        raise NotImplementedError
+
+    async def __call__(
+        self,
+        request: AnyPoolingRequest,
+        raw_request: Request | None = None,
+    ) -> Response:
+        ctx = await self._init_ctx(request, raw_request)
+        await self.io_processor.pre_process_online_async(ctx)
+
+        if ctx.pooling_params is None:
+            ctx.pooling_params = self.io_processor.create_pooling_params(request)
+
+        await self._prepare_generators(ctx)
+        await self._collect_batch(ctx)
+        await self.io_processor.post_process_online_async(ctx)
+        return await self._build_response(ctx)
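A minimal sketch of what a concrete endpoint now looks like under this split (modeled on `ServingClassification` further below; the class name and prefix here are illustrative assumptions):

    class ServingExample(PoolingServing):
        request_id_prefix = "example"  # hypothetical prefix

        def init_io_processor(self, *args, **kwargs) -> PoolingIOProcessor:
            # The base __init__ forwards vllm_config, renderer, and
            # chat_template_config; a real subclass returns its
            # task-specific processor here.
            return PoolingIOProcessor(*args, **kwargs)

        async def _build_response(self, ctx: PoolingServeContext) -> Response:
            ...  # shape ctx.final_res_batch into an HTTP response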
@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor


 class ClassifyIOProcessor(PoolingIOProcessor):
-    name = "classification"
+    name = "classify"
+
+
+class TokenClassifyIOProcessor(PoolingIOProcessor):
+    name = "token_classify"
@@ -6,14 +6,11 @@ from typing import TypeAlias
 import numpy as np
 from fastapi.responses import JSONResponse

-from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import ChatTemplateConfig
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
 from vllm.entrypoints.pooling.base.serving import PoolingServing
 from vllm.entrypoints.pooling.typing import PoolingServeContext
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput
-from vllm.renderers import BaseRenderer

 from .io_processor import ClassifyIOProcessor
 from .protocol import (
@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques
 class ServingClassification(PoolingServing):
     request_id_prefix = "classify"

-    def init_io_processor(
-        self,
-        model_config: ModelConfig,
-        renderer: BaseRenderer,
-        chat_template_config: ChatTemplateConfig,
-    ) -> ClassifyIOProcessor:
-        return ClassifyIOProcessor(
-            model_config=model_config,
-            renderer=renderer,
-            chat_template_config=chat_template_config,
-        )
+    def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
+        return ClassifyIOProcessor(*args, **kwargs)

     async def _build_response(
         self,
@@ -37,7 +37,7 @@ logger = init_logger(__name__)


 class EmbedIOProcessor(PoolingIOProcessor):
-    name = "embedding"
+    name = "embed"

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
         request = ctx.request
         if request.truncate == "NONE" and request.max_tokens is not None:
             self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
+
+
+class TokenEmbedIOProcessor(PoolingIOProcessor):
+    name = "token_embed"
@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from typing_extensions import assert_never

-from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import ChatTemplateConfig
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
 from vllm.entrypoints.pooling.base.serving import PoolingServing
 from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import (
 )
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.renderers import BaseRenderer
 from vllm.utils.serial_utils import EmbedDType, Endianness

 logger = init_logger(__name__)

-JSONResponseCLS = get_json_response_cls()
-
 EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]

@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing):
     request_id_prefix = "embd"
     io_processor: EmbedIOProcessor

-    def init_io_processor(
-        self,
-        model_config: ModelConfig,
-        renderer: BaseRenderer,
-        chat_template_config: ChatTemplateConfig,
-    ) -> EmbedIOProcessor:
-        return EmbedIOProcessor(
-            model_config=model_config,
-            renderer=renderer,
-            chat_template_config=chat_template_config,
-        )
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.json_response_cls = get_json_response_cls()
+
+    def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
+        return EmbedIOProcessor(*args, **kwargs)

     async def _build_response(
         self,
@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing):
             data=items,
             usage=usage,
         )
-        return JSONResponseCLS(content=response.model_dump())
+        return self.json_response_cls(content=response.model_dump())

     def _openai_bytes_response(
         self,
@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing):
             media_type=response.media_type,
         )

-    @staticmethod
     def _build_cohere_response_from_ctx(
+        self,
         ctx: PoolingServeContext,
     ) -> JSONResponse:
         request = ctx.request
@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing):
                 ),
             ),
         )
-        return JSONResponse(content=response.model_dump(exclude_none=True))
+        return self.json_response_cls(content=response.model_dump(exclude_none=True))
@@ -1,42 +1,69 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project


-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.entrypoints.chat_utils import ChatTemplateConfig
-from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
-from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
-from vllm.entrypoints.pooling.utils import enable_scoring_api
+from vllm.plugins.io_processors import has_io_processor
 from vllm.renderers import BaseRenderer
 from vllm.tasks import SupportedTask

+from .base.io_processor import PoolingIOProcessor
+from .utils import enable_scoring_api
+

 def init_pooling_io_processors(
     supported_tasks: tuple[SupportedTask, ...],
-    model_config: ModelConfig,
+    vllm_config: VllmConfig,
     renderer: BaseRenderer,
     chat_template_config: ChatTemplateConfig,
 ) -> dict[str, PoolingIOProcessor]:
-    processors: list[tuple[str, type[PoolingIOProcessor]]] = []
+    model_config = vllm_config.model_config
+    processors: dict[str, type[PoolingIOProcessor]] = {}

     if "classify" in supported_tasks:
-        from vllm.entrypoints.pooling.classify.io_processor import ClassifyIOProcessor
+        from .classify.io_processor import ClassifyIOProcessor

+        processors["classify"] = ClassifyIOProcessor
+
+    if "token_classify" in supported_tasks:
+        from .classify.io_processor import TokenClassifyIOProcessor
+
+        processors["token_classify"] = TokenClassifyIOProcessor
+
-        processors.append(("classify", ClassifyIOProcessor))
     if "embed" in supported_tasks:
-        from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
+        from .embed.io_processor import EmbedIOProcessor

-        processors.append(("embed", EmbedIOProcessor))
+        processors["embed"] = EmbedIOProcessor

+    if "token_embed" in supported_tasks:
+        from .embed.io_processor import TokenEmbedIOProcessor
+
+        processors["token_embed"] = TokenEmbedIOProcessor
+
+    if has_io_processor(
+        vllm_config,
+        model_config.io_processor_plugin,
+    ):
+        from .pooling.io_processor import PluginWithIOProcessorPlugins
+
+        processors["plugin"] = PluginWithIOProcessorPlugins
+    elif "plugin" in supported_tasks:
+        from .pooling.io_processor import PluginWithoutIOProcessorPlugins
+
+        processors["plugin"] = PluginWithoutIOProcessorPlugins

     if enable_scoring_api(supported_tasks, model_config):
         score_type = model_config.score_type
+        from .scoring.io_processor import ScoringIOProcessors

         if score_type is not None and score_type in ScoringIOProcessors:
-            processors.append((score_type, ScoringIOProcessors[score_type]))
+            processors[score_type] = ScoringIOProcessors[score_type]

     return {
         task: processor_cls(
-            model_config=model_config,
+            vllm_config=vllm_config,
             renderer=renderer,
             chat_template_config=chat_template_config,
         )
-        for task, processor_cls in processors
+        for task, processor_cls in processors.items()
     }
|
||||
|
||||
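Editor's note: the factory above maps each supported pooling task to an IO-processor class and then instantiates every registered class with the same shared keyword arguments. A minimal, self-contained sketch of that registry pattern (all names here are illustrative, not the vLLM API):

from dataclasses import dataclass


@dataclass
class BaseProcessor:
    config: dict


class ClassifyProcessor(BaseProcessor): ...
class EmbedProcessor(BaseProcessor): ...


def init_processors(supported_tasks: tuple[str, ...], config: dict) -> dict[str, BaseProcessor]:
    registry: dict[str, type[BaseProcessor]] = {}
    if "classify" in supported_tasks:
        registry["classify"] = ClassifyProcessor
    if "embed" in supported_tasks:
        registry["embed"] = EmbedProcessor
    # Every registered class is constructed with the same shared arguments.
    return {task: cls(config=config) for task, cls in registry.items()}


print(init_processors(("embed",), {"model": "demo"}))
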
@@ -3,24 +3,17 @@
from http import HTTPStatus

from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.pooling.serving import ServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation

router = APIRouter()


def pooling(request: Request) -> OpenAIServingPooling | None:
def pooling(request: Request) -> ServingPooling | None:
    return request.app.state.serving_pooling


@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
    if handler is None:
        raise NotImplementedError("The model does not support Pooling API")

    generator = await handler.create_pooling(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, PoolingBytesResponse):
        return StreamingResponse(
            content=generator.content,
            headers=generator.headers,
            media_type=generator.media_type,
        )

    assert_never(generator)
    return await handler(request, raw_request)

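Editor's note: the route no longer inspects the handler's return type; the serving object is awaited directly and is expected to hand back a ready-made Response. A rough sketch of the shape this gives an endpoint (FastAPI used the same way as in this diff; DemoHandler and the route path are hypothetical):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response

app = FastAPI()


class DemoHandler:
    # The handler owns response construction, so the route stays a thin shim.
    async def __call__(self, payload: dict, raw_request: Request) -> Response:
        return JSONResponse(content={"echo": payload})


handler = DemoHandler()


@app.post("/pooling")
async def create_pooling(payload: dict, raw_request: Request) -> Response:
    return await handler(payload, raw_request)
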
vllm/entrypoints/pooling/pooling/io_processor.py (new file, 156 lines)
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any

from vllm import PoolingParams, PoolingRequestOutput
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq

from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext
from .protocol import IOProcessorRequest, IOProcessorResponse

logger = init_logger(__name__)


class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
    name = "plugin"


class PluginWithIOProcessorPlugins(PoolingIOProcessor):
    """IO Processor plugins are a feature that allows pre- and post-processing
    of the model input and output for pooling models."""

    name = "plugin"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )

        assert io_processor is not None
        self.io_processor = io_processor

    #######################################
    # online APIs

    def pre_process_online(self, ctx: PoolingServeContext):
        assert isinstance(ctx.request, IOProcessorRequest)

        validated_prompt = self.io_processor.parse_data(ctx.request.data)

        raw_prompts = self.io_processor.pre_process(
            prompt=validated_prompt, request_id=ctx.request_id
        )

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(self.model_config, prompt)
            )
            for prompt in prompt_to_seq(raw_prompts)
        ]

        tok_params = ctx.request.build_tok_params(self.model_config)

        ctx.engine_inputs = self.renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(ctx.request, k, None)) is not None
            },
        )

        pooling_params = self.io_processor.merge_pooling_params()
        if pooling_params.task is None:
            pooling_params.task = "plugin"
        ctx.pooling_params = pooling_params
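Editor's note: pre_process_online forwards only the optional request fields that are actually set, using a dict comprehension with an assignment expression. A standalone illustration of that filtering idiom (plain Python, no vLLM types; DummyRequest is invented):

class DummyRequest:
    mm_processor_kwargs = {"size": 224}
    cache_salt = None  # unset fields are skipped


req = DummyRequest()
extras = {
    k: v
    for k in ("mm_processor_kwargs", "cache_salt")
    if (v := getattr(req, k, None)) is not None
}
print(extras)  # {'mm_processor_kwargs': {'size': 224}}
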

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        output = self.io_processor.post_process(
            ctx.final_res_batch,
            request_id=ctx.request_id,
        )

        if callable(
            output_to_response := getattr(self.io_processor, "output_to_response", None)
        ):
            logger.warning_once(
                "`IOProcessor.output_to_response` is deprecated. To ensure "
                "consistency between offline and online APIs, "
                "`IOProcessorResponse` will become a transparent wrapper "
                "around output data from v0.19 onwards.",
            )

            if hasattr(output, "request_id") and output.request_id is None:
                output.request_id = ctx.request_id  # type: ignore

            ctx.response = output_to_response(output)  # type: ignore
        else:
            ctx.response = IOProcessorResponse(request_id=ctx.request_id, data=output)

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        assert isinstance(ctx.prompts, dict) and "data" in ctx.prompts
        assert ctx.pooling_params is not None

        # Validate the request data is valid for the loaded plugin
        prompt_data = ctx.prompts.get("data")
        if prompt_data is None:
            raise ValueError(
                "The 'data' field of the prompt is expected to contain "
                "the prompt data and it cannot be None. "
                "Refer to the documentation of the IOProcessor "
                "in use for more details."
            )
        validated_prompt = self.io_processor.parse_data(prompt_data)

        # obtain the actual model prompts from the pre-processor
        prompts = self.io_processor.pre_process(prompt=validated_prompt)
        prompts_seq = prompt_to_seq(prompts)

        params_seq: list[PoolingParams] = [
            self.io_processor.merge_pooling_params(param)
            for param in self._params_to_seq(
                ctx.pooling_params,
                len(prompts_seq),
            )
        ]
        for p in params_seq:
            if p.task is None:
                p.task = "plugin"

        ctx.pooling_params = params_seq
        ctx.prompts = prompts_seq
        return super().pre_process_offline(ctx)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        processed_outputs = self.io_processor.post_process(ctx.outputs)

        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                num_cached_tokens=getattr(processed_outputs, "num_cached_tokens", 0),
                prompt_token_ids=[],
                finished=True,
            )
        ]
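Editor's note: both the online and offline post-processing paths fall back to a plugin's legacy output_to_response hook when it exists, otherwise they wrap the output in IOProcessorResponse. A minimal sketch of that getattr-and-callable fallback (LegacyPlugin and NewPlugin are made up for illustration):

class LegacyPlugin:
    def output_to_response(self, output):
        return {"wrapped_by_plugin": output}


class NewPlugin:
    pass


def build_response(plugin, output, request_id):
    # Prefer the deprecated hook if the plugin still provides one.
    if callable(hook := getattr(plugin, "output_to_response", None)):
        return hook(output)
    return {"request_id": request_id, "data": output}


print(build_response(LegacyPlugin(), [0.1, 0.2], "req-1"))
print(build_response(NewPlugin(), [0.1, 0.2], "req-2"))
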
@@ -1,252 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
import time
from collections.abc import AsyncGenerator, Callable, Sequence
from collections.abc import Callable
from functools import partial
from typing import Final, Literal, cast
from typing import Literal, cast

from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingRequest,
    PoolingResponse,
    PoolingResponseData,
)
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.entrypoints.pooling.utils import (
    encode_pooling_bytes,
    encode_pooling_output_base64,
    encode_pooling_output_float,
    get_json_response_cls,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.tasks import SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
from vllm.utils.serial_utils import EmbedDType, Endianness

logger = init_logger(__name__)


class OpenAIServingPooling(OpenAIServing):
class ServingPooling(PoolingServingBase):
    request_id_prefix = "pooling"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        openai_serving_render: OpenAIServingRender,
        *args,
        supported_tasks: tuple[SupportedTask, ...],
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
        )
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.supported_tasks = supported_tasks
        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
        self.openai_serving_render = openai_serving_render
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template
        self.io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            vllm_config=self.vllm_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )
        self.json_response_cls = get_json_response_cls()

    async def create_pooling(
    async def __call__(
        self,
        request: PoolingRequest,
        request: AnyPoolingRequest,
        raw_request: Request | None = None,
    ) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse:
        """
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret
    ) -> Response:
        assert isinstance(request, PoolingRequest)
        pooling_task = self._verify_pooling_task(request)

        model_name = self.models.model_name()
        io_processor = self.io_processors[pooling_task]
        ctx = await self._init_ctx(request, raw_request)

        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())
        await io_processor.pre_process_online_async(ctx)

        lora_request = self._maybe_get_adapters(request)
        if ctx.pooling_params is None:
            ctx.pooling_params = io_processor.create_pooling_params(request)

        await self._prepare_generators(ctx)
        await self._collect_batch(ctx)

        await io_processor.post_process_online_async(ctx)
        return await self._build_response(ctx)

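Editor's note: the new __call__ is a fixed pipeline: initialize a context, let the task-specific IO processor pre-process it, run the engine, then post-process and build the response. A toy version of that context-passing pipeline (all names invented for illustration, not the vLLM classes):

import asyncio
from dataclasses import dataclass, field


@dataclass
class Ctx:
    request: dict
    engine_inputs: list = field(default_factory=list)
    results: list = field(default_factory=list)
    response: dict | None = None


class DemoServing:
    async def __call__(self, request: dict) -> dict | None:
        ctx = Ctx(request=request)
        await self.pre_process(ctx)
        await self.run_engine(ctx)
        await self.post_process(ctx)
        return ctx.response

    async def pre_process(self, ctx: Ctx) -> None:
        ctx.engine_inputs = [ctx.request["input"]]

    async def run_engine(self, ctx: Ctx) -> None:
        ctx.results = [len(x) for x in ctx.engine_inputs]

    async def post_process(self, ctx: Ctx) -> None:
        ctx.response = {"data": ctx.results}


print(asyncio.run(DemoServing()({"input": "hello"})))
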
    def _verify_pooling_task(self, request: PoolingRequest) -> str:
        if getattr(request, "dimensions", None) is not None:
            raise ValueError("dimensions is currently not supported")

        if request.task is None:
            request.task = self.pooling_task

        if getattr(request, "dimensions", None) is not None:
            return self.create_error_response("dimensions is currently not supported")
        if isinstance(request, IOProcessorRequest):
            request.task = "plugin"

        assert request.task is not None
        pooling_task = request.task

        # plugin task uses io_processor.parse_request to verify inputs
        if request.task != "plugin" and request.task != self.pooling_task:
            if request.task not in self.supported_tasks:
        if pooling_task != "plugin" and pooling_task != self.pooling_task:
            if pooling_task not in self.io_processors:
                raise ValueError(
                    f"Unsupported task: {request.task!r} "
                    f"Unsupported task: {pooling_task!r} "
                    f"Supported tasks: {self.supported_tasks}"
                )
            else:
                logger.warning_once(
                    "Pooling multitask support is deprecated and will be removed "
                    "in v0.20. When the default pooling task is not what you want, you "
                    'need to manually specify it via --pooler-config.task "%s". ',
                    request.task,
                    "need to manually specify it via --pooler-config.task %s. ",
                    pooling_task,
                )

        engine_inputs: Sequence[EngineInput]
        if use_io_processor := isinstance(request, IOProcessorRequest):
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            validated_prompt = self.io_processor.parse_data(request.data)

            raw_prompts = await self.io_processor.pre_process_async(
                prompt=validated_prompt, request_id=request_id
            )
            engine_inputs = await self.openai_serving_render.preprocess_cmpl(
                request,
                prompt_to_seq(raw_prompts),
            )
        elif isinstance(request, PoolingChatRequest):
            error_check_ret = self.openai_serving_render.validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret

            _, engine_inputs = await self.openai_serving_render.preprocess_chat(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=None,
            )
        elif isinstance(request, PoolingCompletionRequest):
            engine_inputs = await self.openai_serving_render.preprocess_completion(
                request,
                prompt_input=request.input,
                prompt_embeds=None,
            )
        else:
            raise ValueError(f"Unsupported request of type {type(request)}")

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        if use_io_processor:
            assert self.io_processor is not None

            pooling_params = self.io_processor.merge_pooling_params()
            if pooling_params.task is None:
                pooling_params.task = "plugin"
        else:
            pooling_params = request.to_pooling_params()  # type: ignore

        for i, engine_input in enumerate(engine_inputs):
            request_id_item = f"{request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_input,
                params=pooling_params,
                lora_request=lora_request,
        if pooling_task == "plugin" and "plugin" not in self.io_processors:
            raise ValueError(
                "No IOProcessor plugin installed. Please refer "
                "to the documentation and to the "
                "'prithvi_geospatial_mae_io_processor' "
                "offline inference example for more details."
            )

        trace_headers = (
            None
            if raw_request is None
            else await self._get_trace_headers(raw_request.headers)
        return pooling_task

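Editor's note: task validation now checks the request's task against the instantiated io_processors mapping rather than a static list, so only tasks that actually have a processor are accepted. A sketch of the idea (toy names, not the vLLM signatures):

def verify_task(task: str | None, default_task: str, processors: dict[str, object]) -> str:
    task = task or default_task
    if task != default_task and task not in processors:
        raise ValueError(f"Unsupported task: {task!r}. Supported tasks: {tuple(processors)}")
    return task


print(verify_task(None, "embed", {"embed": object()}))     # falls back to the default
print(verify_task("embed", "embed", {"embed": object()}))  # explicit, registered task
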
    async def _build_response(
        self,
        ctx: PoolingServeContext,
    ) -> Response:
        if ctx.response is not None:
            # for IOProcessorResponse
            return self.json_response_cls(content=ctx.response.model_dump())

        encoding_format = ctx.request.encoding_format
        embed_dtype = ctx.request.embed_dtype
        endianness = ctx.request.endianness

        if encoding_format == "float" or encoding_format == "base64":
            return self.request_output_to_pooling_json_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

            generator = self.engine_client.encode(
                engine_input,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
        if encoding_format == "bytes" or encoding_format == "bytes_only":
            return self.request_output_to_pooling_bytes_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

            generators.append(generator)

        result_generator = merge_async_iterators(*generators)

        if use_io_processor:
            assert self.io_processor is not None
            output = await self.io_processor.post_process_async(
                result_generator,
                request_id=request_id,
            )

            if callable(
                output_to_response := getattr(
                    self.io_processor, "output_to_response", None
                )
            ):
                logger.warning_once(
                    "`IOProcessor.output_to_response` is deprecated. To ensure "
                    "consistency between offline and online APIs, "
                    "`IOProcessorResponse` will become a transparent wrapper "
                    "around output data from v0.19 onwards.",
                )

                if hasattr(output, "request_id") and output.request_id is None:
                    output.request_id = request_id  # type: ignore

                return output_to_response(output)  # type: ignore

            return IOProcessorResponse(request_id=request_id, data=output)

        assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
        num_prompts = len(engine_inputs)

        # Non-streaming response
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)

            response = self.request_output_to_pooling_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                request.encoding_format,
                request.embed_dtype,
                request.endianness,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        return response
        assert_never(encoding_format)

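Editor's note: assert_never makes the encoding_format dispatch exhaustive; if a new literal value is ever added without a branch, type checkers flag it and the runtime raises. A self-contained example of that pattern (the literal values mirror this diff, the function is illustrative):

from typing import Literal

from typing_extensions import assert_never

EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]


def describe(fmt: EncodingFormat) -> str:
    if fmt == "float" or fmt == "base64":
        return "JSON response"
    if fmt == "bytes" or fmt == "bytes_only":
        return "streaming bytes response"
    assert_never(fmt)  # unreachable unless a new literal is added


print(describe("float"))
print(describe("bytes"))
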
    def request_output_to_pooling_json_response(
        self,
@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing):
        encoding_format: Literal["float", "base64"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingResponse:
    ) -> JSONResponse:
        encode_fn = cast(
            Callable[[PoolingRequestOutput], list[float] | str],
            (
@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing):
            total_tokens=num_prompt_tokens,
        )

        return PoolingResponse(
        response = PoolingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
        return self.json_response_cls(content=response.model_dump())

    def request_output_to_pooling_bytes_response(
        self,
@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing):
        encoding_format: Literal["bytes", "bytes_only"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingBytesResponse:
    ) -> StreamingResponse:
        content, items, usage = encode_pooling_bytes(
            pooling_outputs=final_res_batch,
            embed_dtype=embed_dtype,
@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing):
            }
        )

        return PoolingBytesResponse(content=content, headers=headers)
        response = PoolingBytesResponse(content=content, headers=headers)

    def request_output_to_pooling_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: EncodingFormat,
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> PoolingResponse | PoolingBytesResponse:
        if encoding_format == "float" or encoding_format == "base64":
            return self.request_output_to_pooling_json_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
            return self.request_output_to_pooling_bytes_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        assert_never(encoding_format)
        return StreamingResponse(
            content=response.content,
            headers=response.headers,
            media_type=response.media_type,
        )

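Editor's note: the bytes path now returns a ready StreamingResponse built from the intermediate PoolingBytesResponse fields. StreamingResponse itself accepts an iterator of byte chunks plus headers and a media type; a minimal standalone use (the route, header, and payload are fabricated for illustration):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo-bytes")
def demo_bytes() -> StreamingResponse:
    chunks = (part for part in (b"\x00\x01", b"\x02\x03"))  # stand-in payload
    headers = {"X-Item-Count": "2"}
    return StreamingResponse(
        content=chunks, headers=headers, media_type="application/octet-stream"
    )
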
@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        assert isinstance(ctx.prompts, ScoringData)
        assert not isinstance(ctx.pooling_params, list)
        assert not isinstance(ctx.pooling_params, Sequence)

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})

@@ -4,15 +4,13 @@
from fastapi.responses import JSONResponse, Response

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from vllm.v1.pool.late_interaction import (
    build_late_interaction_doc_params,
    build_late_interaction_query_params,
@@ -52,22 +50,17 @@ class ServingScores(PoolingServing):
        super().__init__(engine_client, *args, **kwargs)

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
        self, vllm_config: VllmConfig, *args, **kwargs
    ) -> PoolingIOProcessor:
        model_config = vllm_config.model_config

        score_type: str = model_config.score_type
        if self.enable_flash_late_interaction:
            score_type = "flash-late-interaction"

        assert score_type in ScoringIOProcessors
        processor_cls = ScoringIOProcessors[score_type]
        return processor_cls(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )
        return processor_cls(vllm_config, *args, **kwargs)

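Editor's note: init_io_processor now takes the whole VllmConfig and picks the processor class out of ScoringIOProcessors by score_type, with the flash late-interaction flag overriding the configured type. The selection pattern in isolation (toy registry and flag, not the vLLM classes):

class CrossEncoderProcessor: ...
class FlashLateInteractionProcessor: ...


SCORING_PROCESSORS = {
    "cross-encoder": CrossEncoderProcessor,
    "flash-late-interaction": FlashLateInteractionProcessor,
}


def pick_processor(score_type: str, enable_flash_late_interaction: bool):
    if enable_flash_late_interaction:
        score_type = "flash-late-interaction"  # the flag wins over the configured type
    assert score_type in SCORING_PROCESSORS
    return SCORING_PROCESSORS[score_type]()


print(type(pick_processor("cross-encoder", enable_flash_late_interaction=True)).__name__)
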
    async def __call__(self, *args, **kwargs) -> Response:
        if not self.enable_flash_late_interaction:

@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import (
)
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.inputs import EngineInput
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest

PoolingCompletionLikeRequest: TypeAlias = (
@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
    ## for bi-encoder & late-interaction
    n_queries: int | None = None

    ## for IOProcessorResponse
    response: Any | None = None


@dataclass
class OfflineInputsContext:
    prompts: PromptType | Sequence[PromptType] | ScoringData
    pooling_params: PoolingParams | list[PoolingParams] | None = None
    prompts: PromptType | Sequence[PromptType] | DataPrompt | ScoringData
    pooling_params: PoolingParams | Sequence[PoolingParams]
    tokenization_kwargs: dict[str, Any] | None = None
    chat_template: str | None = None

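Editor's note: OfflineInputsContext widens the accepted prompt types and makes pooling_params required instead of optional. A small dataclass mirroring that shape (types simplified; PromptLike and DemoInputsContext stand in for the real vLLM unions):

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

PromptLike = str | dict[str, Any]


@dataclass
class DemoInputsContext:
    prompts: PromptLike | Sequence[PromptLike]
    pooling_params: dict | Sequence[dict]  # now required, single or per-prompt
    tokenization_kwargs: dict[str, Any] | None = None
    chat_template: str | None = None


ctx = DemoInputsContext(prompts=["a", "b"], pooling_params={"task": "embed"})
print(ctx)
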
@@ -14,7 +14,7 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]


@@ -65,7 +65,6 @@ class OpenAIServingRender:
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        io_processor: Any,
        model_registry: OpenAIModelRegistry,
        *,
        request_logger: RequestLogger | None,
@@ -81,7 +80,6 @@ class OpenAIServingRender:
    ) -> None:
        self.model_config = model_config
        self.renderer = renderer
        self.io_processor = io_processor
        self.model_registry = model_registry
        self.request_logger = request_logger
        self.chat_template = chat_template

@@ -12,6 +12,23 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)


def has_io_processor(
    vllm_config: VllmConfig,
    plugin_from_init: str | None = None,
):
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # A plugin can be specified via the model config
        # Retrieve the model specific plugin if available
        # This is using a custom field in the hf_config for the model
        hf_config = vllm_config.model_config.hf_config.to_dict()
        config_plugin = hf_config.get("io_processor_plugin")
        model_plugin = config_plugin

    return model_plugin is not None


def get_io_processor(
    vllm_config: VllmConfig,
    renderer: BaseRenderer,

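Editor's note: has_io_processor answers one question: is a plugin named anywhere, either passed in explicitly or stored under io_processor_plugin in the model's hf_config? A stripped-down version of that lookup (plain dicts instead of the vLLM config objects):

def has_io_processor(hf_config: dict, plugin_from_init: str | None = None) -> bool:
    # An explicit argument wins; otherwise fall back to the model config field.
    plugin = plugin_from_init or hf_config.get("io_processor_plugin")
    return plugin is not None


print(has_io_processor({}, "my_plugin"))               # True: explicit override
print(has_io_processor({"io_processor_plugin": "x"}))  # True: from the model config
print(has_io_processor({}))                            # False: no plugin anywhere
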
@@ -26,7 +26,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs.preprocess import extract_prompt_components
@@ -133,11 +132,6 @@ class AsyncLLM(EngineClient):
        )

        self.renderer = renderer = renderer_from_config(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )

        # Convert EngineInput --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)

@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs.preprocess import extract_prompt_components
@@ -90,11 +89,6 @@ class LLMEngine:
        self.should_execute_dummy_batch = False

        self.renderer = renderer = renderer_from_config(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )

        # Convert EngineInput --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)