Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -48,10 +48,9 @@ def create_mock_request_output(
|
||||
)
|
||||
|
||||
|
||||
async def generate_mock_outputs(num_turns,
|
||||
prompt_token_counts,
|
||||
output_token_counts,
|
||||
cached_token_counts=None):
|
||||
async def generate_mock_outputs(
|
||||
num_turns, prompt_token_counts, output_token_counts, cached_token_counts=None
|
||||
):
|
||||
"""Generate a sequence of mock RequestOutput objects to simulate multiple
|
||||
turns."""
|
||||
if cached_token_counts is None:
|
||||
@@ -73,8 +72,9 @@ async def generate_mock_outputs(num_turns,
|
||||
@pytest.fixture
|
||||
def mock_parser():
|
||||
"""Set up a mock parser for tests."""
|
||||
with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant"
|
||||
) as mock_parser_factory:
|
||||
with patch(
|
||||
"vllm.entrypoints.context.get_streamable_parser_for_assistant"
|
||||
) as mock_parser_factory:
|
||||
# Create a mock parser object
|
||||
parser = MagicMock()
|
||||
parser.messages = []
|
||||
@@ -124,9 +124,9 @@ async def test_multi_turn_token_counting():
|
||||
prompt_token_counts = [5, 15, 20]
|
||||
output_token_counts = [3, 4, 5]
|
||||
cached_token_counts = [0, 5, 15]
|
||||
mock_generator = generate_mock_outputs(3, prompt_token_counts,
|
||||
output_token_counts,
|
||||
cached_token_counts)
|
||||
mock_generator = generate_mock_outputs(
|
||||
3, prompt_token_counts, output_token_counts, cached_token_counts
|
||||
)
|
||||
|
||||
# First turn - initial prompt and response
|
||||
mock_output1 = await async_next(mock_generator)
|
||||
@@ -251,7 +251,7 @@ async def test_single_turn_no_tool_output():
|
||||
"""Test that first turn never generates tool output tokens."""
|
||||
context = HarmonyContext(
|
||||
messages=[],
|
||||
available_tools=["browser"] # Tools available
|
||||
available_tools=["browser"], # Tools available
|
||||
)
|
||||
|
||||
# Even with large prompt in first turn, no tool tokens should be counted
|
||||
@@ -333,21 +333,24 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
|
||||
output_token_ids=[101], # Single token
|
||||
num_cached_tokens=0,
|
||||
finished=False, # Not end of message yet
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Second token of first turn
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
output_token_ids=[102],
|
||||
finished=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Last token of first turn (finished=True signals end of message)
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
output_token_ids=[103],
|
||||
finished=True, # End of message
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Check token counts after first turn
|
||||
assert context.num_prompt_tokens == 3 # Initial prompt tokens
|
||||
@@ -362,25 +365,36 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
|
||||
# First token of second turn
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
prompt_token_ids=[1, 2, 3, 101, 102, 103, 4,
|
||||
5], # 8 tokens (includes previous)
|
||||
prompt_token_ids=[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
101,
|
||||
102,
|
||||
103,
|
||||
4,
|
||||
5,
|
||||
], # 8 tokens (includes previous)
|
||||
output_token_ids=[201],
|
||||
num_cached_tokens=3, # Some tokens cached
|
||||
finished=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# More tokens in reasoning channel
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
output_token_ids=[202],
|
||||
finished=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
output_token_ids=[203],
|
||||
finished=True, # End of reasoning message
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Check counts after second turn (reasoning message)
|
||||
assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt
|
||||
@@ -399,18 +413,32 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
prompt_token_ids=[
|
||||
1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
101,
|
||||
102,
|
||||
103,
|
||||
4,
|
||||
5,
|
||||
201,
|
||||
202,
|
||||
203,
|
||||
6,
|
||||
7,
|
||||
], # 13 tokens
|
||||
output_token_ids=[301],
|
||||
num_cached_tokens=8, # More cached tokens
|
||||
finished=False,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
context.append_output(
|
||||
create_mock_request_output(
|
||||
output_token_ids=[302],
|
||||
finished=True,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Final token counts check
|
||||
assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts
|
||||
@@ -421,8 +449,9 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
|
||||
# Additional tool tokens from third turn
|
||||
# Formula: this turn prompt - last turn prompt - last turn output
|
||||
additional_tool_tokens = 13 - 8 - 3 # = 2
|
||||
assert context.num_tool_output_tokens == expected_tool_tokens \
|
||||
+ additional_tool_tokens
|
||||
assert (
|
||||
context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -442,8 +471,7 @@ async def test_streaming_message_synchronization(mock_parser):
|
||||
recipient=Role.ASSISTANT,
|
||||
)
|
||||
]
|
||||
context = StreamingHarmonyContext(messages=initial_messages,
|
||||
available_tools=[])
|
||||
context = StreamingHarmonyContext(messages=initial_messages, available_tools=[])
|
||||
|
||||
# Verify initial state
|
||||
assert len(context._messages) == 1
|
||||
@@ -461,9 +489,10 @@ async def test_streaming_message_synchronization(mock_parser):
|
||||
|
||||
# This should trigger the message synchronization logic
|
||||
context.append_output(
|
||||
create_mock_request_output(prompt_token_ids=[1, 2, 3],
|
||||
output_token_ids=[101],
|
||||
finished=False))
|
||||
create_mock_request_output(
|
||||
prompt_token_ids=[1, 2, 3], output_token_ids=[101], finished=False
|
||||
)
|
||||
)
|
||||
|
||||
# Verify that messages were synchronized
|
||||
assert len(context._messages) == 2
|
||||
@@ -485,12 +514,13 @@ async def test_streaming_message_synchronization(mock_parser):
|
||||
author=Author(role=Role.ASSISTANT, name="assistant"),
|
||||
content=[TextContent(text="Response 4")],
|
||||
recipient=Role.USER,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Create another output to trigger synchronization again
|
||||
mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3],
|
||||
output_token_ids=[102],
|
||||
finished=True)
|
||||
mock_output2 = create_mock_request_output(
|
||||
prompt_token_ids=[1, 2, 3], output_token_ids=[102], finished=True
|
||||
)
|
||||
|
||||
context.append_output(mock_output2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user