[Misc] IO Processor plugins for pooling models (#22820)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -37,13 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import (_validate_truncation_size,
|
||||
log_non_default_args)
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput, RequestOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
|
||||
SamplingParams)
|
||||
@@ -284,6 +286,11 @@ class LLM:
|
||||
|
||||
self.supported_tasks = supported_tasks
|
||||
|
||||
# Load the Input/Output processor plugin if any
|
||||
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
|
||||
io_processor_plugin)
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@@ -833,7 +840,7 @@ class LLM:
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
@@ -915,6 +922,22 @@ class LLM:
|
||||
if truncate_prompt_tokens is not None:
|
||||
param.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
|
||||
io_processor_prompt = False
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
io_processor_prompt = True
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
# Validate the request data is valid for the loaded plugin
|
||||
validated_prompt = self.io_processor.parse_request(prompts)
|
||||
|
||||
# obtain the actual model prompts from the pre-processor
|
||||
prompts = self.io_processor.pre_process(prompt=validated_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=pooling_params,
|
||||
@@ -923,8 +946,24 @@ class LLM:
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
|
||||
model_outputs = self.engine_class.validate_outputs(
|
||||
outputs, PoolingRequestOutput)
|
||||
|
||||
if io_processor_prompt:
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(
|
||||
model_output=model_outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](request_id="",
|
||||
outputs=processed_outputs,
|
||||
prompt_token_ids=[],
|
||||
finished=True)
|
||||
]
|
||||
else:
|
||||
return model_outputs
|
||||
|
||||
def embed(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user