diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 8a7892cf6..10879f0be 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -282,9 +282,11 @@ async def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 +@pytest.mark.parametrize("model_type", ["gpt_oss", "any"]) @pytest.mark.asyncio -async def test_serving_chat_did_set_correct_cache_salt(): +async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() + mock_model_config.hf_config.model_type = model_type mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1789521af..d57868847 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1483,4 +1483,9 @@ class OpenAIServingChat(OpenAIServing): # Render prompt token ids. prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 86c16df40..1b30fa01e 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -408,6 +408,11 @@ class OpenAIServingResponses(OpenAIServing): request, prev_response) prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + + # Add cache_salt if provided in the request + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return messages, [prompt_token_ids], [engine_prompt] async def responses_full_generator(