[Renderer] Move Processor out of LLMEngine (#26165)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-03 23:08:22 +08:00 (committed by GitHub)
Parent: 73a99cc2a5
Commit: d78fda7cda
4 changed files with 114 additions and 59 deletions


@@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
get_cached_tokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.processor import Processor
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
@@ -312,6 +316,10 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
@property
def model_config(self):
return self.llm_engine.model_config
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()
@@ -324,6 +332,16 @@ class LLM:
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = self.llm_engine.vllm_config
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self._processor = Processor(vllm_config, tokenizer)
return self._processor
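
The _get_processor helper above builds the Processor lazily and caches it on the instance, so constructing an LLM does not pay for tokenizer or processor setup until a request actually needs input processing. A minimal sketch of the same caching idiom, where OfflineClient and SimpleProcessor are illustrative stand-ins rather than vLLM classes:

from types import SimpleNamespace

class SimpleProcessor:
    # Illustrative stand-in for vllm.v1.engine.processor.Processor.
    def __init__(self, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer

class OfflineClient:
    def __init__(self, config):
        self.config = config  # cheap to construct; nothing heavy happens here

    def _get_processor(self):
        # Build the processor on first use and cache it, mirroring the
        # hasattr() check in LLM._get_processor above.
        if not hasattr(self, "_processor"):
            tokenizer = None if self.config.skip_tokenizer_init else object()
            self._processor = SimpleProcessor(self.config, tokenizer)
        return self._processor

client = OfflineClient(SimpleNamespace(skip_tokenizer_init=False))
assert client._get_processor() is client._get_processor()  # same cached instance
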
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = (
@@ -1497,8 +1515,6 @@ class LLM:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
if isinstance(prompt, dict):
@@ -1506,17 +1522,9 @@ class LLM:
prompt.get("multi_modal_data"),
prompt.get("multi_modal_uuids"))
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
priority=priority[i] if priority else 0,
@@ -1557,23 +1565,59 @@ class LLM:
raise ValueError(f"Multi-modal data for {modality} is None"
f" but UUID is not provided")
def _add_request(
def _process_inputs(
self,
prompt: PromptType,
request_id: str,
engine_prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
*,
lora_request: Optional[LoRARequest],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs)
processor = self._get_processor()
engine_request = processor.process_inputs(
request_id,
prompt,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
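
_process_inputs first asks _validate_truncation_size to translate params.truncate_prompt_tokens into tokenizer keyword arguments (checked against model_config.max_model_len), then hands the prompt to Processor.process_inputs to obtain the EngineCoreRequest that LLMEngine will execute. A rough, illustrative sketch of what such a truncation-validation helper might do, assuming the HuggingFace-style truncation/max_length convention; this is not the actual vllm.entrypoints.utils implementation:

from typing import Any, Optional

def validate_truncation_size_sketch(
    max_model_len: int,
    truncate_prompt_tokens: Optional[int],
    tokenization_kwargs: dict[str, Any],
) -> Optional[int]:
    # Hypothetical helper for illustration only.
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens < 0:
        # A negative value conventionally means "truncate to max_model_len".
        truncate_prompt_tokens = max_model_len
    if truncate_prompt_tokens > max_model_len:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) is larger "
            f"than max_model_len ({max_model_len})")
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = truncate_prompt_tokens
    return truncate_prompt_tokens
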
def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
priority: int = 0,
) -> None:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))
engine_request, tokenization_kwargs = self._process_inputs(
request_id,
prompt,
params,
lora_request=lora_request,
priority=priority,
)
self.llm_engine.add_request(
request_id,
engine_request,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
prompt_text=prompt_text,
)
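
For users of the public API nothing changes: a typical offline run still looks like the snippet below (the model name is only an example). The difference is that each prompt is now turned into an EngineCoreRequest by LLM._process_inputs, which calls the standalone Processor, before LLMEngine.add_request is invoked.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any small model works as an example
params = SamplingParams(max_tokens=16, truncate_prompt_tokens=512)

# Prompt processing (tokenization, truncation, multi-modal handling) now runs
# in LLM._process_inputs via Processor.process_inputs, outside LLMEngine.
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
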
def _run_engine(
self,