add io_process_plugin for sparse embedding (#34214)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com> Signed-off-by: Augusto Yao <augusto.yjh@antgroup.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -1390,6 +1390,10 @@ steps:
|
||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
# test bge_m3_sparse io_processor plugin
|
||||
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||
- pip uninstall bge_m3_sparse_plugin -y
|
||||
# end io_processor plugins test
|
||||
# begin stat_logger plugins test
|
||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||
@@ -2967,6 +2971,10 @@ steps:
|
||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
# test bge_m3_sparse io_processor plugin
|
||||
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||
- pip uninstall bge_m3_sparse_plugin -y
|
||||
# end io_processor plugins test
|
||||
# begin stat_logger plugins test
|
||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||
@@ -3248,4 +3256,4 @@ steps:
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
@@ -19,6 +19,10 @@ steps:
|
||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
# test bge_m3_sparse io_processor plugin
|
||||
- pip install -e ./plugins/bge_m3_sparse_plugin
|
||||
- pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
|
||||
- pip uninstall bge_m3_sparse_plugin -y
|
||||
# end io_processor plugins test
|
||||
# begin stat_logger plugins test
|
||||
- pip install -e ./plugins/vllm_add_dummy_stat_logger
|
||||
|
||||
@@ -13,12 +13,13 @@ IOProcessorInput = TypeVar("IOProcessorInput")
|
||||
IOProcessorOutput = TypeVar("IOProcessorOutput")
|
||||
|
||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
"""Abstract interface for pre/post-processing of engine I/O."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
|
||||
super().__init__()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
@abstractmethod
|
||||
def parse_data(self, data: object) -> IOProcessorInput:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -32,7 +33,7 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||
self,
|
||||
params: PoolingParams | None = None,
|
||||
) -> PoolingParams:
|
||||
return params or PoolingParams()
|
||||
return params or PoolingParams(task="plugin")
|
||||
|
||||
@abstractmethod
|
||||
def pre_process(
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
def register_bge_m3_sparse_embeddings_processor():
    """Return the dotted path of the BGE-M3 sparse embeddings processor class."""
    processor_qualname = "bge_m3_sparse_processor.sparse_embeddings_processor.BgeM3SparseEmbeddingsProcessor"  # noqa: E501
    return processor_qualname
|
||||
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import (
|
||||
IOProcessor,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
|
||||
|
||||
from .types import (
|
||||
SparseEmbeddingCompletionRequestMixin,
|
||||
SparseEmbeddingResponse,
|
||||
SparseEmbeddingResponseData,
|
||||
SparseEmbeddingTokenWeight,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BgeM3SparseEmbeddingsProcessor(
    IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
):
    """IO processor that turns per-token pooling outputs into sparse embeddings.

    ``pre_process`` remembers every request so that ``post_process`` can look
    up its ``return_tokens`` option afterwards. Requests that carry a
    ``request_id`` are stored in ``online_requests`` keyed by that id; requests
    without one are appended to ``offline_requests``.
    """

    def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
        super().__init__(vllm_config, renderer)
        # Pending requests, filled by pre_process and drained by post_process.
        self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
        self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
        # Kept to reach the tokenizer when token text is requested.
        self.renderer: BaseRenderer = renderer

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return pooling params with the task forced to ``token_classify``.

        A new ``PoolingParams`` is created when ``params`` is ``None``;
        otherwise the given object is updated in place and returned.
        """
        if params is None:
            params = PoolingParams()
        # refer to PoolingCompletionRequest.to_pooling_params
        params.task = "token_classify"
        return params

    def parse_request(
        self, request_data: object
    ) -> SparseEmbeddingCompletionRequestMixin:
        """Build the request model from a raw dictionary.

        Raises:
            TypeError: if ``request_data`` is not a dictionary.
        """
        # for vllm.entrypoints.llm.LLM, offline mode, calls `encode` directly.
        if isinstance(request_data, dict):
            return SparseEmbeddingCompletionRequestMixin(**request_data)
        raise TypeError("request_data should be a dictionary")

    def pre_process(
        self,
        prompt: SparseEmbeddingCompletionRequestMixin,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Remember ``prompt`` for ``post_process`` and return its ``input``.

        Raises:
            AssertionError: if ``request_id`` is already pending.
        """
        if request_id is None:
            self.offline_requests.append(prompt)
        elif request_id in self.online_requests:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError("request_id duplicated")
        else:
            self.online_requests[request_id] = prompt
        return prompt.input

    def _get_sparse_embedding_request(self, request_id: str | None = None):
        """Pop and return the pending request that belongs to ``request_id``.

        Returns ``None`` when ``request_id`` is given but unknown.
        """
        # Must mirror the ``is None`` test in pre_process: a falsy but
        # non-None id (e.g. "") is stored in online_requests.
        if request_id is not None:
            return self.online_requests.pop(request_id, None)
        # NOTE(review): pops the most recently added request (LIFO); confirm
        # offline pre_process/post_process calls are never interleaved.
        return self.offline_requests.pop()

    def _build_sparse_embedding_token_weights(
        self,
        sparse_embedding: dict[int, float],
        return_tokens: bool = False,
    ) -> list[SparseEmbeddingTokenWeight]:
        """Convert a ``token_id -> weight`` mapping into response entries.

        The ``token`` field is filled from the renderer's tokenizer only when
        ``return_tokens`` is true and a renderer is available; otherwise it is
        ``None`` for every entry.
        """
        # Materialize the ids: the tokenizer helper gets a real list rather
        # than a dict view.
        token_ids = list(sparse_embedding)
        tokens: list[str | None] = [None] * len(token_ids)

        if return_tokens and self.renderer is not None:
            tokens = convert_ids_list_to_tokens(
                self.renderer.get_tokenizer(), token_ids
            )

        return [
            SparseEmbeddingTokenWeight(token_id=token_id, weight=weight, token=token)
            for token_id, weight, token in zip(
                token_ids, sparse_embedding.values(), tokens
            )
        ]

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> SparseEmbeddingResponse:
        """Aggregate per-token weights into one sparse embedding per output.

        For each output, a token id that occurs several times keeps its
        largest weight. Usage counts the full prompt length of every output.

        Raises:
            ValueError: if no pending request is found for ``request_id``.
        """
        request = self._get_sparse_embedding_request(request_id)
        if request is None:
            raise ValueError(
                f"no pending sparse embedding request found for {request_id=}"
            )
        return_tokens = bool(request.return_tokens)

        num_prompt_tokens = 0
        response_data: list[SparseEmbeddingResponseData] = []
        for idx, mo in enumerate(model_output):
            # Work on a local view so the caller's output object is not
            # modified.
            token_ids = mo.prompt_token_ids
            num_prompt_tokens += len(token_ids)
            if len(token_ids) != len(mo.outputs.data):
                # this is the case that add_special_tokens is True,
                # which means first token and last token are special tokens
                # (the trailing one is dropped by zip() below).
                token_ids = token_ids[1:]

            sparse_embedding: dict[int, float] = {}
            for token_id, weight in zip(token_ids, mo.outputs.data.tolist()):
                sparse_embedding[token_id] = max(
                    weight, sparse_embedding.get(token_id, 0.0)
                )

            response_data.append(
                SparseEmbeddingResponseData(
                    index=idx,
                    sparse_embedding=self._build_sparse_embedding_token_weights(
                        sparse_embedding,
                        return_tokens,
                    ),
                )
            )

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )
        return SparseEmbeddingResponse(
            data=response_data,
            usage=usage,
        )
|
||||
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin
|
||||
|
||||
|
||||
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
    # Completion-style request with one extra option: whether each returned
    # token id should be accompanied by its token text.
    return_tokens: bool | None = Field(
        default=None,
        # The trailing space keeps the two concatenated sentences apart.
        description="Whether to return dict shows the mapping of token_id to text. "
        "`None` or False means not return.",
    )
|
||||
|
||||
|
||||
class SparseEmbeddingTokenWeight(BaseModel):
    # One entry of a sparse embedding: a token id and its weight.
    token_id: int
    weight: float
    # Token text; must always be supplied, but may be None.
    token: str | None
|
||||
|
||||
|
||||
class SparseEmbeddingResponseData(BaseModel):
    # Sparse embedding for a single input, identified by its position.
    index: int
    # Constant type tag carried by every entry.
    object: str = Field(default="sparse-embedding")
    sparse_embedding: list[SparseEmbeddingTokenWeight]
|
||||
|
||||
|
||||
class SparseEmbeddingResponse(BaseModel):
    # Complete plugin response: one data entry per input plus token usage.
    data: list[SparseEmbeddingResponseData]
    usage: UsageInfo
|
||||
15
tests/plugins/bge_m3_sparse_plugin/setup.py
Normal file
15
tests/plugins/bge_m3_sparse_plugin/setup.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
# Entry point in "<plugin name> = <module>:<function>" form; it is listed
# under the "vllm.io_processor_plugins" group below.
_IO_PROCESSOR_ENTRY_POINT = "bge_m3_sparse_plugin = bge_m3_sparse_processor:register_bge_m3_sparse_embeddings_processor"  # noqa: E501

setup(
    name="bge-m3-sparse-plugin",
    version="0.1",
    packages=["bge_m3_sparse_processor"],
    entry_points={"vllm.io_processor_plugins": [_IO_PROCESSOR_ENTRY_POINT]},
)
|
||||
@@ -22,6 +22,7 @@ from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
|
||||
|
||||
@@ -218,8 +219,8 @@ def load_image(
|
||||
class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
|
||||
indices = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
|
||||
super().__init__(vllm_config, renderer)
|
||||
|
||||
self.datamodule = Sen1Floods11NonGeoDataModule(
|
||||
data_root=datamodule_config["data_root"],
|
||||
|
||||
212
tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
Normal file
212
tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Test configuration for BGE-M3 sparse plugin
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
|
||||
|
||||
# Model overrides, serialized to JSON below; the offline tests decode them
# back with json.loads.
_HF_OVERRIDES = {"architectures": ["BgeM3EmbeddingModel"], "head_dtype": "float16"}

# Shared settings used by every test in this module.
model_config = {
    "model_name": "BAAI/bge-m3",
    "plugin": "bge_m3_sparse_plugin",
    "test_input": "What is the capital of France?",
    "hf_overrides": json.dumps(_HF_OVERRIDES),
}
|
||||
|
||||
|
||||
def _float_close(expected: object, result: object):
|
||||
assert isinstance(expected, float) and isinstance(result, float), (
|
||||
f"{expected=} or {result=} is not float"
|
||||
)
|
||||
return (expected - result) < 1e-3 or abs(expected / result - 1) < 1e-3
|
||||
|
||||
|
||||
def _get_attr_or_val(obj: object | dict, key: str):
|
||||
if isinstance(obj, dict) and key in obj:
|
||||
return obj[key]
|
||||
return getattr(obj, key, None)
|
||||
|
||||
|
||||
def _check_sparse_embedding(data, check_tokens=False):
    """Assert that ``data`` matches the reference embedding of the test input.

    Every entry must carry a known token id and a weight close to the
    reference value. With ``check_tokens`` the token text must match as well;
    without it every entry's token must be ``None``.
    """
    # (token_id, weight, token) triples for the reference embedding.
    reference = (
        (32, 0.0552978515625, "?"),
        (70, 0.09808349609375, "the"),
        (83, 0.08154296875, "is"),
        (111, 0.11810302734375, "of"),
        (4865, 0.1171875, "What"),
        (9942, 0.292236328125, "France"),
        (10323, 0.2802734375, "capital"),
    )
    expected_by_id = {
        token_id: {"token_id": token_id, "weight": weight, "token": token}
        for token_id, weight, token in reference
    }

    assert len(data) == len(expected_by_id)
    for item in data:
        want = expected_by_id[_get_attr_or_val(item, "token_id")]
        assert _float_close(want["weight"], _get_attr_or_val(item, "weight")), (
            f"actual embed {item} not equal to {want}"
        )
        if not check_tokens:
            assert _get_attr_or_val(item, "token") is None, (
                f"{item} should not return token"
            )
            continue
        assert want["token"] == _get_attr_or_val(item, "token"), (
            f"actual embed {item} not equal to {want}"
        )
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def server():
    """Yield a pooling server started with the BGE-M3 sparse plugin enabled."""
    cli_args = ["--runner", "pooling", "--enforce-eager", "--max-num-seqs", "32"]
    cli_args += ["--hf_overrides", model_config["hf_overrides"]]
    cli_args += ["--io-processor-plugin", model_config["plugin"]]

    with RemoteOpenAIServer(model_config["model_name"], cli_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("return_tokens", [True, False])
async def test_bge_m3_sparse_plugin_online(
    server: RemoteOpenAIServer, return_tokens: bool
):
    """Test BGE-M3 sparse plugin in online mode via API."""
    plugin_input = {
        "input": model_config["test_input"],
        "return_tokens": return_tokens,
    }
    http_response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_config["model_name"],
            "task": "token_classify",
            "data": plugin_input,
        },
    )
    body = http_response.json()

    # The envelope must parse as an IOProcessorResponse with a non-empty payload.
    assert (parsed_response := IOProcessorResponse(**body).data)

    # The plugin payload must carry at least one data entry.
    entries = _get_attr_or_val(parsed_response, "data")
    assert entries
    assert len(entries) > 0

    first_entry = entries[0]
    assert _get_attr_or_val(first_entry, "object") == "sparse-embedding"

    # The sparse embedding must be a non-empty list matching the reference.
    sparse_embedding = _get_attr_or_val(first_entry, "sparse_embedding")
    assert sparse_embedding
    assert isinstance(sparse_embedding, list)
    _check_sparse_embedding(sparse_embedding, return_tokens)

    # Usage must be present, with total tokens equal to prompt tokens.
    usage = _get_attr_or_val(parsed_response, "usage")
    assert usage, f"usage not found for {parsed_response}"
    prompt_tokens = _get_attr_or_val(usage, "prompt_tokens")
    assert prompt_tokens > 0
    assert _get_attr_or_val(usage, "total_tokens") == prompt_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("return_tokens", [True, False])
def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
    """Test BGE-M3 sparse plugin in offline mode."""
    plugin_input = {
        "input": model_config["test_input"],
        "return_tokens": return_tokens,
    }
    runner_kwargs = dict(
        runner="pooling",
        enforce_eager=True,
        max_num_seqs=32,
        io_processor_plugin=model_config["plugin"],
        hf_overrides=json.loads(model_config["hf_overrides"]),
        default_torch_num_threads=1,
    )

    with vllm_runner(model_config["model_name"], **runner_kwargs) as llm_runner:
        llm = llm_runner.get_llm()
        pooler_output = llm.encode(
            {"data": plugin_input}, pooling_task="token_classify"
        )

    first_output = pooler_output[0]

    # The plugin response is exposed through the `outputs` attribute.
    assert hasattr(first_output, "outputs")
    response = first_output.outputs
    assert hasattr(response, "data")
    assert len(response.data) == 1

    # Every entry must hold a sparse embedding matching the reference.
    for entry in response.data:
        sparse_embedding = entry.sparse_embedding
        assert isinstance(sparse_embedding, list)
        _check_sparse_embedding(sparse_embedding, return_tokens)

    # Usage: total tokens must equal prompt tokens.
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.total_tokens == usage.prompt_tokens
|
||||
|
||||
|
||||
def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
    """Test BGE-M3 sparse plugin with multiple inputs in offline mode."""
    questions = [
        "What is the capital of France?",
        "What is the capital of Germany?",
        "What is the capital of Spain?",
    ]
    runner_kwargs = dict(
        runner="pooling",
        enforce_eager=True,
        max_num_seqs=32,
        io_processor_plugin=model_config["plugin"],
        hf_overrides=json.loads(model_config["hf_overrides"]),
        default_torch_num_threads=1,
    )

    with vllm_runner(model_config["model_name"], **runner_kwargs) as llm_runner:
        llm = llm_runner.get_llm()
        pooler_output = llm.encode(
            {"data": {"input": questions, "return_tokens": True}},
            pooling_task="token_classify",
        )

    first_output = pooler_output[0]

    # The plugin response is exposed through the `outputs` attribute.
    assert hasattr(first_output, "outputs")
    response = first_output.outputs
    assert hasattr(response, "data")
    assert len(response.data) == 3

    # Every entry must hold a sparse embedding list.
    for entry in response.data:
        assert isinstance(entry.sparse_embedding, list)

    # Usage: total tokens must equal prompt tokens.
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.total_tokens == usage.prompt_tokens
|
||||
@@ -39,7 +39,7 @@ def _compute_image_hash(base64_data: str) -> str:
|
||||
def test_loading_missing_plugin():
|
||||
vllm_config = VllmConfig()
|
||||
with pytest.raises(ValueError):
|
||||
get_io_processor(vllm_config, "wrong_plugin")
|
||||
get_io_processor(vllm_config, None, "wrong_plugin")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.plugins import IO_PROCESSOR_PLUGINS_GROUP, load_plugins_by_group
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_io_processor(
|
||||
vllm_config: VllmConfig, plugin_from_init: str | None = None
|
||||
vllm_config: VllmConfig,
|
||||
renderer: BaseRenderer,
|
||||
plugin_from_init: str | None = None,
|
||||
) -> IOProcessor | None:
|
||||
# Input/Output processors are loaded as plugins under the
|
||||
# 'vllm.io_processor_plugins' group. Similar to platform
|
||||
@@ -65,4 +69,14 @@ def get_io_processor(
|
||||
|
||||
activated_plugin_cls = loadable_plugins[model_plugin]
|
||||
|
||||
return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
|
||||
activated_plugin_typ = resolve_obj_by_qualname(activated_plugin_cls)
|
||||
|
||||
# for backward compatibility, the plugin does not have a renderer argument
|
||||
if "renderer" not in inspect.signature(activated_plugin_typ.__init__).parameters:
|
||||
logger.warning(
|
||||
"The renderer argument will be required in v0.18, "
|
||||
"please update your IOProcessor plugin: %s",
|
||||
activated_plugin_cls,
|
||||
)
|
||||
return activated_plugin_typ(vllm_config)
|
||||
return activated_plugin_typ(vllm_config, renderer)
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
IOProcessorInput = TypeVar("IOProcessorInput")
|
||||
@@ -18,7 +19,7 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")
|
||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||
"""Abstract interface for pre/post-processing of engine I/O."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
|
||||
super().__init__()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
@@ -135,6 +135,7 @@ class AsyncLLM(EngineClient):
|
||||
self.renderer = renderer = renderer_from_config(self.vllm_config)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.renderer,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ class LLMEngine:
|
||||
self.renderer = renderer = renderer_from_config(self.vllm_config)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.renderer,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user