diff --git a/tests/renderers/test_completions.py b/tests/renderers/test_completions.py
index 7e33a1f9f..17274ccd7 100644
--- a/tests/renderers/test_completions.py
+++ b/tests/renderers/test_completions.py
@@ -10,7 +10,6 @@ import pybase64
 import pytest
 import torch
 
-from vllm.inputs.data import is_embeds_prompt
 from vllm.renderers import TokenizeParams
 from vllm.renderers.hf import HfRenderer
 from vllm.tokenizers.registry import tokenizer_args_from_config
@@ -320,7 +319,6 @@ class TestRenderEmbedPrompt:
         )
 
         assert len(results) == 1
-        assert is_embeds_prompt(results[0])
         assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
 
     @pytest.mark.asyncio
@@ -342,7 +340,6 @@ class TestRenderEmbedPrompt:
 
         assert len(results) == 2
         for i, result in enumerate(results):
-            assert is_embeds_prompt(result)
             assert torch.allclose(result["prompt_embeds"], test_tensors[i])
 
     @pytest.mark.asyncio
@@ -420,7 +417,7 @@ class TestRenderEmbedPrompt:
 
         assert len(results) == 2
         # First should be embed prompt
-        assert is_embeds_prompt(results[0])
+        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
         # Second should be tokens prompt
         assert "prompt_token_ids" in results[1]
         assert results[1]["prompt_token_ids"] == [101, 102, 103]
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 0df62f512..e618b11ad 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
 from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import EmbedsPrompt, TokensPrompt
-from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -359,7 +358,7 @@ class OpenAIServingChat(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                prompt_text, _, _ = get_prompt_components(engine_prompt)
+                prompt_text = engine_prompt.get("prompt")
 
                 # If we are creating sub requests for multiple prompts, ensure that they
                 # have unique request ids.
diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index dc59d5248..8981b8662 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -34,8 +34,7 @@ from vllm.entrypoints.openai.engine.serving import (
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
-from vllm.inputs.parse import get_prompt_components
+from vllm.inputs.data import EmbedsPrompt, TokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -161,7 +160,7 @@ class OpenAIServingCompletion(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                prompt_text, _, _ = get_prompt_components(engine_prompt)
+                prompt_text = engine_prompt.get("prompt")
 
                 max_tokens = get_max_tokens(
                     max_model_len=self.max_model_len,
@@ -278,11 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
             # with the inputs token IDs
             if final_res.prompt is None:
                 engine_prompt = engine_prompts[i]
-                final_res.prompt = (
-                    None
-                    if is_embeds_prompt(engine_prompt)
-                    else engine_prompt.get("prompt")
-                )
+                final_res.prompt = engine_prompt.get("prompt")
 
         final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
 
@@ -352,11 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 prompt_text = res.prompt
                 if prompt_text is None:
                     engine_prompt = engine_prompts[prompt_idx]
-                    prompt_text = (
-                        None
-                        if is_embeds_prompt(engine_prompt)
-                        else engine_prompt.get("prompt")
-                    )
+                    prompt_text = engine_prompt.get("prompt")
 
                 # Prompt details are excluded from later streamed outputs
                 if prompt_token_ids is not None:
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index df8e7b19f..7f9300a1a 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -1116,7 +1116,7 @@ class OpenAIServing:
         priority: int = 0,
         trace_headers: Mapping[str, str] | None = None,
     ):
-        prompt_text, _, _ = get_prompt_components(engine_prompt)
+        prompt_text = engine_prompt.get("prompt")
 
         orig_priority = priority
         sub_request = 0
@@ -1186,7 +1186,7 @@ class OpenAIServing:
             context.chat_template_content_format,
         )
         engine_prompt = engine_prompts[0]
-        prompt_text, _, _ = get_prompt_components(engine_prompt)
+        prompt_text = engine_prompt.get("prompt")
 
         sampling_params.max_tokens = get_max_tokens(
             self.max_model_len,
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 92970fc39..315ffddde 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
 
 import torch
-from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
+from typing_extensions import NotRequired, TypedDict, TypeVar
 
 from vllm.sampling_params import SamplingParams
@@ -77,6 +77,9 @@ class EmbedsPrompt(_CommonKeys):
     prompt_embeds: torch.Tensor
     """The embeddings of the prompt."""
 
+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token embeddings, if available."""
+
 
 class DataPrompt(_CommonKeys):
     """Represents generic inputs handled by IO processor plugins."""
@@ -113,22 +116,6 @@ more than one prompt, i.e.
 """
 
 
-def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
-    return (
-        isinstance(prompt, dict)
-        and "prompt_token_ids" in prompt
-        and "prompt_embeds" not in prompt
-    )
-
-
-def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
-    return (
-        isinstance(prompt, dict)
-        and "prompt_token_ids" not in prompt
-        and "prompt_embeds" in prompt
-    )
-
-
 _T1_co = TypeVar(
     "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
 )
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 4872992b6..0f1b7f46d 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -4,7 +4,7 @@
 import time
 from collections.abc import Callable, Mapping
 from copy import copy
-from typing import Any, cast
+from typing import Any
 
 import torch.nn as nn
 from typing_extensions import TypeVar
@@ -32,6 +32,7 @@ from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.input_processor import InputProcessor
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
+from vllm.v1.engine.utils import get_prompt_text
 from vllm.v1.executor import Executor
 from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
@@ -245,10 +246,7 @@ class LLMEngine:
             trace_headers,
             priority,
         )
-        if isinstance(prompt, str):
-            prompt_text = prompt
-        elif isinstance(prompt, Mapping):
-            prompt_text = cast(str | None, prompt.get("prompt"))
+        prompt_text = get_prompt_text(prompt)
 
         self.input_processor.assign_request_id(request)
 
diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py
index cae613920..44465f4d0 100644
--- a/vllm/v1/engine/utils.py
+++ b/vllm/v1/engine/utils.py
@@ -4,12 +4,12 @@
 import contextlib
 import os
 import weakref
-from collections.abc import Callable, Iterator, Mapping
+from collections.abc import Callable, Iterator
 from dataclasses import dataclass
 from enum import Enum, auto
 from multiprocessing import Process, connection
 from multiprocessing.process import BaseProcess
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING
 from unittest.mock import patch
 
 import msgspec
@@ -17,6 +17,8 @@ import zmq
 
 from vllm import envs
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
+from vllm.inputs import PromptType
+from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
@@ -224,12 +226,8 @@ def get_device_indices(
     return value
 
 
-def get_prompt_text(prompt: Any) -> str | None:
-    if isinstance(prompt, str):
-        return prompt
-    if isinstance(prompt, Mapping):
-        return cast(str | None, prompt.get("prompt"))
-    return None
+def get_prompt_text(prompt: PromptType) -> str | None:
+    return get_prompt_components(prompt)[0]
 
 
 class CoreEngineActorManager: