[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
-from vllm.inputs import LLMInputs, PromptInputs
|
||||
+from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
@@ -227,6 +227,9 @@ class LLMEngine:
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
model_config)
|
||||
|
||||
+self.input_processor = INPUT_REGISTRY.create_input_processor(
|
||||
+self.model_config)
|
||||
+
||||
self.model_executor = executor_class(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
@@ -511,9 +514,11 @@ class LLMEngine:
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
-return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
-prompt=inputs.get("prompt"),
|
||||
-multi_modal_data=inputs.get("multi_modal_data"))
|
||||
+llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
+prompt=inputs.get("prompt"),
|
||||
+multi_modal_data=inputs.get("multi_modal_data"))
|
||||
+
|
||||
+return self.input_processor(llm_inputs)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user