diff --git a/tests/v1/e2e/test_streaming_input.py b/tests/v1/e2e/test_streaming_input.py
index a1eaa065a..4c9b43099 100644
--- a/tests/v1/e2e/test_streaming_input.py
+++ b/tests/v1/e2e/test_streaming_input.py
@@ -19,7 +19,7 @@ import pytest
 import pytest_asyncio
 
 from vllm import SamplingParams
-from vllm.inputs.data import StreamingInput
+from vllm.inputs import StreamingInput
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
diff --git a/tests/v1/streaming_input/test_async_llm_streaming.py b/tests/v1/streaming_input/test_async_llm_streaming.py
index 992634387..b5ba757d0 100644
--- a/tests/v1/streaming_input/test_async_llm_streaming.py
+++ b/tests/v1/streaming_input/test_async_llm_streaming.py
@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
-from vllm.inputs.data import StreamingInput
+from vllm.inputs import StreamingInput
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.engine.async_llm import AsyncLLM
diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py
index d9aed70c9..0fdb3ab5e 100644
--- a/vllm/inputs/__init__.py
+++ b/vllm/inputs/__init__.py
@@ -12,6 +12,7 @@ from .data import (
     PromptType,
     SingletonInputs,
     SingletonPrompt,
+    StreamingInput,
     TextPrompt,
     TokenInputs,
     TokensPrompt,
@@ -41,4 +42,5 @@ __all__ = [
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
+    "StreamingInput",
 ]
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index f1a3e341f..2beb9c4f8 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -16,8 +16,7 @@ from vllm import TokensPrompt
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
-from vllm.inputs import PromptType
-from vllm.inputs.data import StreamingInput
+from vllm.inputs import PromptType, StreamingInput
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -461,6 +460,7 @@ class AsyncLLM(EngineClient):
                 self._validate_streaming_input_sampling_params(sp)
             else:
                 sp = sampling_params
+            # TODO(nick): Avoid re-validating reused sampling parameters
             req = self.input_processor.process_inputs(
                 request_id=internal_req_id,
                 prompt=input_chunk.prompt,
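
Note on the re-export: `StreamingInput` now lives in the public `vllm.inputs` namespace instead of only the internal `vllm.inputs.data` module, and the test imports are updated to the public path. A minimal sketch of what this change enables, assuming a build that includes it (the internal path continues to resolve, since `__init__.py` re-exports rather than redefines the class):

```python
# The public spelling introduced by this diff:
from vllm.inputs import StreamingInput

# The old internal path still works, because the __init__.py change
# is a re-export ("from .data import StreamingInput"), not a move.
from vllm.inputs.data import StreamingInput as _internal_streaming_input

# Sanity check: both names are aliases for the same class object,
# so isinstance checks and subclassing remain compatible either way.
assert StreamingInput is _internal_streaming_input
```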