[Misc] Reorganize inputs (#35182)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -105,7 +105,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
)
|
||||
|
||||
async def _fake_preprocess_chat(*args, **kwargs):
|
||||
# return conversation, engine_prompts
|
||||
# return conversation, engine_inputs
|
||||
return (
|
||||
[{"role": "user", "content": "Test"}],
|
||||
[{"prompt_token_ids": [1, 2, 3]}],
|
||||
|
||||
@@ -958,14 +958,14 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
orig_render_chat_request = serving_chat.render_chat_request
|
||||
captured_prompts = []
|
||||
captured_inputs = []
|
||||
|
||||
async def render_chat_request(request):
|
||||
result = await orig_render_chat_request(request)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
conversation, engine_prompts = result
|
||||
captured_prompts.extend(engine_prompts)
|
||||
conversation, engine_inputs = result
|
||||
captured_inputs.extend(engine_inputs)
|
||||
|
||||
return result
|
||||
|
||||
@@ -981,18 +981,18 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert "cache_salt" not in captured_prompts[0]
|
||||
assert len(captured_inputs) == 1
|
||||
assert "cache_salt" not in captured_inputs[0]
|
||||
|
||||
captured_prompts.clear()
|
||||
captured_inputs.clear()
|
||||
|
||||
# Test with certain cache_salt
|
||||
req.cache_salt = "test_salt"
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert captured_prompts[0]["cache_salt"] == "test_salt"
|
||||
assert len(captured_inputs) == 1
|
||||
assert captured_inputs[0]["cache_salt"] == "test_salt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.responses.serving import (
|
||||
from vllm.entrypoints.openai.responses.streaming_events import (
|
||||
StreamingState,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs import tokens_input
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
@@ -258,20 +258,20 @@ class TestValidateGeneratorInput:
|
||||
"""Test _validate_generator_input with valid prompt length"""
|
||||
# Create an engine prompt with valid length (less than max_model_len)
|
||||
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids)
|
||||
engine_input = tokens_input(valid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
result = serving_responses_instance._validate_generator_input(engine_input)
|
||||
|
||||
# Should return None for valid input
|
||||
assert result is None
|
||||
|
||||
# create an invalid engine prompt
|
||||
invalid_prompt_token_ids = list(range(200)) # 200 tokens >= 100 max_model_len
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
|
||||
engine_input = tokens_input(invalid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
result = serving_responses_instance._validate_generator_input(engine_input)
|
||||
|
||||
# Should return an ErrorResponse
|
||||
assert result is not None
|
||||
|
||||
Reference in New Issue
Block a user