diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index 07b7933f6..e5bb47587 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -6,6 +6,7 @@ import pytest from .utils import ( MESSAGES_WITHOUT_TOOLS, + SEED, WEATHER_TOOL, ServerConfig, ensure_system_prompt, @@ -27,6 +28,7 @@ async def test_chat_completion_without_tools( max_completion_tokens=150, model=model_name, logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -47,6 +49,7 @@ async def test_chat_completion_without_tools( max_completion_tokens=150, model=model_name, logprobs=False, + seed=SEED, stream=True, ) chunks: list[str] = [] @@ -97,6 +100,7 @@ async def test_chat_completion_with_tools( model=model_name, tools=[WEATHER_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -118,6 +122,7 @@ async def test_chat_completion_with_tools( model=model_name, logprobs=False, tools=[WEATHER_TOOL], + seed=SEED, stream=True, ) diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index 77084ec2d..ed8c80d36 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -10,6 +10,7 @@ from .utils import ( MESSAGES_ASKING_FOR_PARALLEL_TOOLS, MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + SEED, WEATHER_TOOL, ServerConfig, ) @@ -39,6 +40,7 @@ async def test_parallel_tool_calls( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -76,6 +78,7 @@ async def test_parallel_tool_calls( max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -166,6 +169,7 @@ async def test_parallel_tool_calls_with_results( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -184,6 +188,7 @@ async def test_parallel_tool_calls_with_results( model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -229,6 +234,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, parallel_tool_calls=False, ) @@ -247,6 +253,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, parallel_tool_calls=False, stream=True, ) diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 6614b6415..f719a886c 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -10,6 +10,7 @@ from .utils import ( MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, SEARCH_TOOL, + SEED, WEATHER_TOOL, ) @@ -27,6 +28,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -71,6 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) @@ -154,6 +157,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, ) choice = chat_completion.choices[0] @@ -171,6 +175,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, + seed=SEED, stream=True, ) diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index de7284a30..5a03f53ec 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -42,6 +42,8 @@ def ensure_system_prompt( # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. +SEED = 42 + ARGS: list[str] = [ "--enable-auto-tool-choice", "--max-model-len",