[Refactor] Pass Renderer to Input Processor (#34329)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -125,6 +125,7 @@ class TestInitializeToolSessions:
|
||||
engine_client = MagicMock()
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.max_model_len = 100
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
@@ -212,6 +213,7 @@ class TestValidateGeneratorInput:
|
||||
engine_client = MagicMock()
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.max_model_len = 100
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
@@ -231,9 +233,6 @@ class TestValidateGeneratorInput:
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
|
||||
# Set max_model_len for testing
|
||||
instance.max_model_len = 100
|
||||
|
||||
return instance
|
||||
|
||||
def test_validate_generator_input(self, serving_responses_instance):
|
||||
|
||||
@@ -507,7 +507,8 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||
|
||||
# In case the hybrid model does not have the
|
||||
# "mamba_block_size" assume a fixed constant
|
||||
@@ -660,7 +661,8 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||
|
||||
# In case the hybrid model does not have the
|
||||
# "mamba_block_size" assume a fixed constant
|
||||
|
||||
@@ -25,7 +25,8 @@ def test_classify_models(
|
||||
with vllm_runner(
|
||||
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
|
||||
) as vllm_model:
|
||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
assert cache_config.enable_prefix_caching
|
||||
|
||||
# First Run
|
||||
@@ -74,7 +75,8 @@ def test_embed_models(
|
||||
max_model_len=None,
|
||||
enable_prefix_caching=True,
|
||||
) as vllm_model:
|
||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
assert cache_config.enable_prefix_caching
|
||||
|
||||
# First Run
|
||||
@@ -106,5 +108,6 @@ def test_non_causal_models(
|
||||
hf_runner, vllm_runner, example_prompts, model: str, dtype: str
|
||||
) -> None:
|
||||
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
assert not cache_config.enable_prefix_caching
|
||||
|
||||
@@ -161,7 +161,8 @@ def test_pooling_prefix_cache(vllm_runner, monkeypatch):
|
||||
assert chunks[0] <= prompt1_len
|
||||
assert chunks[0] < prompt2_len
|
||||
|
||||
cache_config = llm.get_llm().llm_engine.cache_config
|
||||
vllm_config = llm.get_llm().llm_engine.vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
print(f"{cache_config=}")
|
||||
# Prefixes are cached in blocks
|
||||
assert (prompt2_len - chunks[0]) % cache_config.block_size == 0
|
||||
|
||||
@@ -311,7 +311,8 @@ def test_get_logprobs_and_prompt_logprobs(
|
||||
temperature: "temperature" sampling parameter
|
||||
example_prompts: example prompt fixture
|
||||
"""
|
||||
do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
|
||||
vllm_config = vllm_model.llm.llm_engine.vllm_config
|
||||
do_apc = vllm_config.cache_config.enable_prefix_caching
|
||||
if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
|
||||
# Skip some test-cases to save time.
|
||||
pytest.skip()
|
||||
|
||||
@@ -54,7 +54,7 @@ class PoolerConfig:
|
||||
Reduce the dimensions of embeddings if model
|
||||
support matryoshka representation. Defaults to None.
|
||||
"""
|
||||
enable_chunked_processing: bool | None = None
|
||||
enable_chunked_processing: bool = False
|
||||
"""
|
||||
Whether to enable chunked processing for long inputs that exceed the model's
|
||||
maximum position embeddings. When enabled, long inputs will be split into
|
||||
|
||||
@@ -31,12 +31,9 @@ class EngineClient(ABC):
|
||||
|
||||
vllm_config: VllmConfig
|
||||
model_config: ModelConfig
|
||||
input_processor: InputProcessor
|
||||
renderer: BaseRenderer
|
||||
io_processor: IOProcessor | None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def renderer(self) -> BaseRenderer: ...
|
||||
input_processor: InputProcessor
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -356,8 +356,9 @@ class LLM:
|
||||
self.supported_tasks = supported_tasks
|
||||
|
||||
self.model_config = self.llm_engine.model_config
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
self.renderer = self.llm_engine.renderer
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
|
||||
# Cache for __repr__ to avoid repeated collective_rpc calls
|
||||
self._cached_repr: str | None = None
|
||||
@@ -816,7 +817,7 @@ class LLM:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
renderer = self.llm_engine.renderer
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
parsed_prompts = [
|
||||
@@ -858,7 +859,7 @@ class LLM:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
renderer = self.llm_engine.renderer
|
||||
renderer = self.renderer
|
||||
|
||||
chat_params = ChatParams(
|
||||
chat_template=chat_template,
|
||||
|
||||
@@ -239,8 +239,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
try:
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.tokenizer
|
||||
tokenizer = self.renderer.tokenizer
|
||||
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
@@ -375,6 +374,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
@@ -387,7 +387,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
max_model_len,
|
||||
request.max_completion_tokens
|
||||
if request.max_completion_tokens is not None
|
||||
else request.max_tokens,
|
||||
|
||||
@@ -157,13 +157,14 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
max_model_len,
|
||||
request.max_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
|
||||
@@ -242,11 +242,10 @@ class OpenAIServing:
|
||||
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.input_processor = self.models.input_processor
|
||||
self.io_processor = self.models.io_processor
|
||||
self.renderer = self.models.renderer
|
||||
self.model_config = self.models.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.model_config = engine_client.model_config
|
||||
self.renderer = engine_client.renderer
|
||||
self.io_processor = engine_client.io_processor
|
||||
self.input_processor = engine_client.input_processor
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
@@ -537,7 +536,7 @@ class OpenAIServing:
|
||||
|
||||
if (
|
||||
truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > self.max_model_len
|
||||
and truncate_prompt_tokens > self.model_config.max_model_len
|
||||
):
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
@@ -844,6 +843,7 @@ class OpenAIServing:
|
||||
input_text: str,
|
||||
) -> TokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
# Note: EmbeddingRequest, ClassificationRequest,
|
||||
# and ScoreRequest doesn't have max_tokens
|
||||
@@ -862,7 +862,7 @@ class OpenAIServing:
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
# since these requests don't generate tokens.
|
||||
if token_num > self.max_model_len:
|
||||
if token_num > max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreDataRequest: "score",
|
||||
ScoreTextRequest: "score",
|
||||
@@ -873,7 +873,7 @@ class OpenAIServing:
|
||||
operation = operations.get(type(request), "embedding generation")
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input.",
|
||||
parameter="input_tokens",
|
||||
@@ -898,22 +898,22 @@ class OpenAIServing:
|
||||
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= self.max_model_len:
|
||||
if token_num >= max_model_len:
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, your request has "
|
||||
f"{max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=token_num,
|
||||
)
|
||||
|
||||
if max_tokens is not None and token_num + max_tokens > self.max_model_len:
|
||||
if max_tokens is not None and token_num + max_tokens > max_model_len:
|
||||
raise VLLMValidationError(
|
||||
"'max_tokens' or 'max_completion_tokens' is too large: "
|
||||
f"{max_tokens}. This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
|
||||
f"{max_model_len} tokens and your request has "
|
||||
f"{token_num} input tokens ({max_tokens} > {max_model_len}"
|
||||
f" - {token_num}).",
|
||||
parameter="max_tokens",
|
||||
value=max_tokens,
|
||||
@@ -1089,6 +1089,7 @@ class OpenAIServing:
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
max_model_len = self.model_config.max_model_len
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
orig_priority = priority
|
||||
@@ -1148,7 +1149,7 @@ class OpenAIServing:
|
||||
token_ids = context.render_for_completion()
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
|
||||
|
||||
sampling_params.max_tokens = self.max_model_len - len(token_ids)
|
||||
sampling_params.max_tokens = max_model_len - len(token_ids)
|
||||
elif isinstance(context, ParsableContext):
|
||||
engine_prompts = await self._render_next_turn(
|
||||
context.request,
|
||||
@@ -1162,7 +1163,7 @@ class OpenAIServing:
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
max_model_len,
|
||||
context.request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params, # type: ignore
|
||||
|
||||
@@ -59,11 +59,10 @@ class OpenAIServingModels:
|
||||
)
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.input_processor = self.engine_client.input_processor
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.renderer = self.engine_client.renderer
|
||||
self.model_config = self.engine_client.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.renderer = self.engine_client.renderer
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.input_processor = self.engine_client.input_processor
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
@@ -96,12 +95,13 @@ class OpenAIServingModels:
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters"""
|
||||
"""Show available models. This includes the base model and all adapters."""
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
model_cards = [
|
||||
ModelCard(
|
||||
id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
max_model_len=max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
|
||||
@@ -296,10 +296,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
) -> ErrorResponse | None:
|
||||
"""Add validations to the input to the generator here."""
|
||||
prompt_len = self._extract_prompt_len(engine_prompt)
|
||||
if self.max_model_len <= prompt_len:
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
if prompt_len >= max_model_len:
|
||||
error_message = (
|
||||
f"The engine prompt length {prompt_len} "
|
||||
f"exceeds the max_model_len {self.max_model_len}. "
|
||||
f"exceeds the max_model_len {max_model_len}. "
|
||||
"Please reduce prompt."
|
||||
)
|
||||
return self.create_error_response(
|
||||
@@ -414,6 +416,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||
|
||||
builtin_tool_list: list[str] = []
|
||||
@@ -431,8 +434,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
assert len(builtin_tool_list) == 0
|
||||
available_tools = []
|
||||
try:
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
for engine_prompt in engine_prompts:
|
||||
maybe_error = self._validate_generator_input(engine_prompt)
|
||||
@@ -440,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
return maybe_error
|
||||
|
||||
default_max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
max_model_len,
|
||||
request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
|
||||
@@ -69,16 +69,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
pooler_config = self.model_config.pooler_config
|
||||
|
||||
# Avoid repeated attribute lookups
|
||||
self.supports_chunked_processing = bool(
|
||||
pooler_config and pooler_config.enable_chunked_processing
|
||||
)
|
||||
self.max_embed_len = (
|
||||
pooler_config.max_embed_len
|
||||
if pooler_config and pooler_config.max_embed_len
|
||||
else None
|
||||
)
|
||||
assert pooler_config is not None
|
||||
self.pooler_config = pooler_config
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
@@ -240,7 +232,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
"""Check if chunked processing should be used for this request."""
|
||||
return (
|
||||
isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
|
||||
and self.supports_chunked_processing
|
||||
and self.pooler_config.enable_chunked_processing
|
||||
)
|
||||
|
||||
async def _process_chunked_request(
|
||||
@@ -310,14 +302,14 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
# Determine the effective max length for validation
|
||||
if self.max_embed_len is not None:
|
||||
if self.pooler_config.max_embed_len:
|
||||
# Use max_embed_len for validation instead of max_model_len
|
||||
length_type = "maximum embedding input length"
|
||||
max_length_value = self.max_embed_len
|
||||
max_length_value = self.pooler_config.max_embed_len
|
||||
else:
|
||||
# Fall back to max_model_len validation (original behavior)
|
||||
length_type = "maximum context length"
|
||||
max_length_value = self.max_model_len
|
||||
max_length_value = self.model_config.max_model_len
|
||||
|
||||
validation_error_msg = (
|
||||
"This model's {length_type} is {max_length_value} tokens. "
|
||||
|
||||
@@ -117,7 +117,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
tokens=input_ids,
|
||||
token_strs=token_strs,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len,
|
||||
max_model_len=self.model_config.max_model_len,
|
||||
)
|
||||
|
||||
async def create_detokenize(
|
||||
|
||||
@@ -16,7 +16,7 @@ from vllm.multimodal.inputs import (
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.renderers import renderer_from_config
|
||||
from vllm.renderers import BaseRenderer, renderer_from_config
|
||||
from vllm.renderers.inputs import (
|
||||
DecoderDictPrompt,
|
||||
DecoderOnlyDictPrompt,
|
||||
@@ -56,6 +56,7 @@ class InputPreprocessor:
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
observability_config: ObservabilityConfig | None = None,
|
||||
renderer: BaseRenderer | None = None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> None:
|
||||
@@ -63,7 +64,7 @@ class InputPreprocessor:
|
||||
|
||||
self.model_config = model_config
|
||||
self.observability_config = observability_config
|
||||
self.renderer = renderer_from_config(model_config)
|
||||
self.renderer = renderer or renderer_from_config(model_config)
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = mm_processor_cache
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers import merge_kwargs, renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
@@ -110,9 +110,10 @@ class AsyncLLM(EngineClient):
|
||||
# Ensure we can serialize custom transformer configs
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
self.model_config = vllm_config.model_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
tracing_endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if tracing_endpoint is not None:
|
||||
init_tracer("vllm.llm_engine", tracing_endpoint)
|
||||
@@ -131,20 +132,22 @@ class AsyncLLM(EngineClient):
|
||||
"enabling logging without default stat loggers."
|
||||
)
|
||||
|
||||
self.input_processor = InputProcessor(self.vllm_config)
|
||||
self.renderer = renderer = renderer_from_config(self.model_config)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
|
||||
# Convert TokPrompt --> EngineCoreRequest.
|
||||
self.input_processor = InputProcessor(self.vllm_config, renderer)
|
||||
|
||||
# Converts EngineCoreOutputs --> RequestOutput.
|
||||
self.output_processor = OutputProcessor(
|
||||
self.tokenizer,
|
||||
renderer.tokenizer,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.vllm_config.scheduler_config.stream_interval,
|
||||
tracing_enabled=tracing_endpoint is not None,
|
||||
)
|
||||
if tracing_endpoint is not None:
|
||||
self.output_processor.tracing_enabled = True
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
self.engine_core = EngineCoreClient.make_async_mp_client(
|
||||
@@ -891,17 +894,13 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_processor.tokenizer
|
||||
return self.renderer.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
return self.input_processor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_processor.renderer
|
||||
return self.renderer.get_tokenizer()
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.observability_config.otlp_traces_endpoint is not None # type: ignore
|
||||
return self.observability_config.otlp_traces_endpoint is not None
|
||||
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.logger_manager:
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers import BaseRenderer, renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
@@ -44,6 +44,8 @@ class InputProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
renderer: BaseRenderer | None = None,
|
||||
*,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
@@ -57,6 +59,7 @@ class InputProcessor:
|
||||
|
||||
self.generation_config_fields = model_config.try_get_generation_config()
|
||||
|
||||
self.renderer = renderer or renderer_from_config(model_config)
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
|
||||
|
||||
@@ -74,20 +77,17 @@ class InputProcessor:
|
||||
self.input_preprocessor = InputPreprocessor(
|
||||
model_config,
|
||||
self.observability_config,
|
||||
mm_registry,
|
||||
renderer=renderer,
|
||||
mm_registry=mm_registry,
|
||||
mm_processor_cache=self.mm_processor_cache,
|
||||
)
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_preprocessor.tokenizer
|
||||
return self.renderer.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
return self.input_preprocessor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_preprocessor.renderer
|
||||
return self.renderer.get_tokenizer()
|
||||
|
||||
def _validate_params(
|
||||
self,
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers import renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -62,9 +62,12 @@ class LLMEngine:
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
tracing_endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if tracing_endpoint is not None:
|
||||
init_tracer("vllm.llm_engine", tracing_endpoint)
|
||||
|
||||
self.log_stats = log_stats
|
||||
|
||||
@@ -87,22 +90,22 @@ class LLMEngine:
|
||||
self.dp_group = None
|
||||
self.should_execute_dummy_batch = False
|
||||
|
||||
self.input_processor = InputProcessor(self.vllm_config)
|
||||
self.renderer = renderer = renderer_from_config(self.model_config)
|
||||
self.io_processor = get_io_processor(
|
||||
self.vllm_config,
|
||||
self.model_config.io_processor_plugin,
|
||||
)
|
||||
|
||||
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
|
||||
# Convert TokPrompt --> EngineCoreRequest.
|
||||
self.input_processor = InputProcessor(self.vllm_config, renderer)
|
||||
|
||||
# Converts EngineCoreOutputs --> RequestOutput.
|
||||
self.output_processor = OutputProcessor(
|
||||
self.tokenizer,
|
||||
renderer.tokenizer,
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.vllm_config.scheduler_config.stream_interval,
|
||||
tracing_enabled=tracing_endpoint is not None,
|
||||
)
|
||||
endpoint = self.observability_config.otlp_traces_endpoint
|
||||
if endpoint is not None:
|
||||
init_tracer("vllm.llm_engine", endpoint)
|
||||
self.output_processor.tracing_enabled = True
|
||||
|
||||
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
|
||||
self.engine_core = EngineCoreClient.make_client(
|
||||
@@ -365,14 +368,10 @@ class LLMEngine:
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
return self.input_processor.tokenizer
|
||||
return self.renderer.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
return self.input_processor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_processor.renderer
|
||||
return self.renderer.get_tokenizer()
|
||||
|
||||
def do_log_stats(self) -> None:
|
||||
"""Log stats if logging is enabled."""
|
||||
|
||||
@@ -417,8 +417,10 @@ class OutputProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
*,
|
||||
log_stats: bool,
|
||||
stream_interval: int = 1,
|
||||
tracing_enabled: bool = False,
|
||||
):
|
||||
self.log_stats = log_stats
|
||||
self.tokenizer = tokenizer
|
||||
@@ -427,7 +429,7 @@ class OutputProcessor:
|
||||
self.parent_requests: dict[str, ParentRequest] = {}
|
||||
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
|
||||
self.lora_states = LoRARequestStates(log_stats)
|
||||
self.tracing_enabled: bool = False
|
||||
self.tracing_enabled = tracing_enabled
|
||||
self._requests_drained = asyncio.Event()
|
||||
self._requests_drained.set()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user