diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49b02279d..3feee01da 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -66,7 +66,7 @@ async def test_evil_forward(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass assert client.errored @@ -115,7 +115,7 @@ async def test_failed_health_check(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass client.close() @@ -157,7 +157,7 @@ async def test_failed_abort(tmp_socket): async for _ in client.generate( prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=10), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass assert "KeyError" in repr(execinfo.value) assert client.errored @@ -189,7 +189,7 @@ async def test_batch_error(tmp_socket): params = SamplingParams(min_tokens=2048, max_tokens=2048) async for _ in client.generate(prompt="Hello my name is", sampling_params=params, - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] @@ -289,7 +289,7 @@ async def test_engine_process_death(tmp_socket): with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), - request_id=uuid.uuid4()): + request_id=str(uuid.uuid4())): pass # And the health check should show the engine is dead diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8fccf9bd2..25fa1c305 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -687,6 +687,10 @@ class LLMEngine: >>> # continue the request processing >>> ... """ + if not isinstance(request_id, str): + raise TypeError( + f"request_id must be a string, got {type(request_id)}") + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 25fab2713..a2328c37b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -192,6 +192,11 @@ class LLMEngine: prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + # Validate the request_id type. + if not isinstance(request_id, str): + raise TypeError( + f"request_id must be a string, got {type(request_id)}") + # Process raw inputs into the request. prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request,