Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -46,29 +46,36 @@ class MinTokensTestCase:
self.expected_exact_len = expected_exact_len
def __str__(self):
return (f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}")
return (
f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}"
)
# Test scenarios covering all critical cases
MIN_TOKENS_TEST_CASES = [
# === BASIC FUNCTIONALITY (should work) ===
MinTokensTestCase(name="basic_min_tokens_no_stop",
min_tokens=8,
max_tokens=20,
stop=None,
expected_min_len=8),
MinTokensTestCase(name="min_tokens_zero",
min_tokens=0,
max_tokens=10,
stop=None,
expected_min_len=0),
MinTokensTestCase(name="min_equals_max_no_stop",
min_tokens=15,
max_tokens=15,
stop=None,
expected_exact_len=15),
MinTokensTestCase(
name="basic_min_tokens_no_stop",
min_tokens=8,
max_tokens=20,
stop=None,
expected_min_len=8,
),
MinTokensTestCase(
name="min_tokens_zero",
min_tokens=0,
max_tokens=10,
stop=None,
expected_min_len=0,
),
MinTokensTestCase(
name="min_equals_max_no_stop",
min_tokens=15,
max_tokens=15,
stop=None,
expected_exact_len=15,
),
# === STOP STRINGS WITH MIN_TOKENS ===
# These tests expose the detokenizer bug where stop strings
# bypass min_tokens
@@ -94,9 +101,11 @@ MIN_TOKENS_TEST_CASES = [
expected_min_len=5,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
reason=(
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
),
strict=False,
),
id="min_tokens_with_comprehensive_stops",
),
pytest.param(
@@ -108,12 +117,13 @@ MIN_TOKENS_TEST_CASES = [
expected_min_len=3,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
reason=(
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
),
strict=False,
),
id="min_tokens_with_simple_char_stop",
),
# === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
# These test the MinTokensLogitsProcessor handling of EOS tokens
pytest.param(
@@ -125,26 +135,26 @@ MIN_TOKENS_TEST_CASES = [
expected_exact_len=20,
),
marks=pytest.mark.xfail(
reason=
("Potential logits-processor bug: EOS tokens may bypass min_tokens"
),
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
strict=False,
),
id="min_equals_max_eos_only",
),
# === EDGE CASES ===
MinTokensTestCase(name="large_min_tokens",
min_tokens=50,
max_tokens=60,
stop=None,
expected_min_len=50),
MinTokensTestCase(
name="large_min_tokens",
min_tokens=50,
max_tokens=60,
stop=None,
expected_min_len=50,
),
MinTokensTestCase(
name="min_tokens_with_empty_stop_list",
min_tokens=5,
max_tokens=15,
stop=[], # Empty stop list
expected_min_len=5),
expected_min_len=5,
),
]
@@ -170,25 +180,27 @@ def get_token_count(output: RequestOutput) -> int:
return len(output.outputs[0].token_ids)
def assert_min_tokens_satisfied(output: RequestOutput,
test_case: MinTokensTestCase) -> None:
def assert_min_tokens_satisfied(
output: RequestOutput, test_case: MinTokensTestCase
) -> None:
"""Assert that min_tokens requirement is satisfied"""
token_count = get_token_count(output)
stop_reason = (output.outputs[0].stop_reason
if output.outputs else "no output")
stop_reason = output.outputs[0].stop_reason if output.outputs else "no output"
if test_case.expected_exact_len is not None:
# Exact length requirement
assert token_count == test_case.expected_exact_len, (
f"Expected exactly {test_case.expected_exact_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
f"Stop reason: {stop_reason}"
)
else:
# Minimum length requirement
assert token_count >= (test_case.expected_min_len or 0), (
f"Expected at least {test_case.expected_min_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
f"Stop reason: {stop_reason}"
)
@pytest.mark.parametrize(
@@ -199,13 +211,13 @@ def assert_min_tokens_satisfied(output: RequestOutput,
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
"""
Comprehensive test for min_tokens functionality in V1 engine.
This test covers all critical scenarios for min_tokens:
- Basic functionality (should work)
- Stop strings with min_tokens (known bug)
- EOS tokens with min_tokens (potential bug)
- Edge cases
Args:
llm_v1: V1 LLM instance
test_case: Test scenario parameters
@@ -218,7 +230,7 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
max_tokens=test_case.max_tokens,
stop=test_case.stop,
temperature=GREEDY,
include_stop_str_in_output=True # Include stop strings for debugging
include_stop_str_in_output=True, # Include stop strings for debugging
)
# Use simple prompt. Comprehensive stop lists should catch any generation
@@ -250,13 +262,11 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
def test_min_tokens_basic_functionality(llm_v1: LLM):
"""
Test basic min_tokens functionality without stop conditions.
This is a baseline test that should always pass and validates
that min_tokens works correctly in the simple case.
"""
sampling_params = SamplingParams(min_tokens=10,
max_tokens=20,
temperature=GREEDY)
sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY)
prompt = "Once upon a time"
outputs = llm_v1.generate([prompt], sampling_params)
@@ -269,17 +279,16 @@ def test_min_tokens_basic_functionality(llm_v1: LLM):
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
"""
Test the specific bug where stop strings bypass min_tokens.
This test specifically reproduces the bug Calvin is fixing in PR #22014.
It should fail until that fix is merged.
Strategy: Use guaranteed stop characters that will appear
in any generated text.
"""
@@ -291,7 +300,8 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
# Common letter; likely appears early
stop=["e"],
temperature=GREEDY,
include_stop_str_in_output=True)
include_stop_str_in_output=True,
)
# Simple prompt that will generate text containing "e"
prompt = "The quick brown fox"
@@ -308,23 +318,25 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
# This assertion should fail due to the bug - if stop string is found early,
# the model should still continue generating until min_tokens is reached
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "no output")
assert token_count >= 15, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=15. "
f"Reason: {stop_reason}. "
f"Text: {repr(generated_text)}")
stop_reason = (
outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output"
)
assert token_count >= 15, (
"Bug confirmed: "
f"{token_count} tokens < min_tokens=15. "
f"Reason: {stop_reason}. "
f"Text: {repr(generated_text)}"
)
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
"""
Guaranteed test for stop strings bypassing min_tokens bug.
Strategy: Use very low temperature and multiple common stop strings
to virtually guarantee early detection, combined with long min_tokens
to ensure the bug is exposed regardless of model behavior.
@@ -337,7 +349,8 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
# Use multiple very common patterns - at least one will appear
stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
temperature=GREEDY,
include_stop_str_in_output=True)
include_stop_str_in_output=True,
)
# Simple prompt that will generate some text
prompt = "The cat"
@@ -346,8 +359,7 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
assert len(outputs) == 1
token_count = get_token_count(outputs[0])
generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "unknown")
stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown"
print(f"Generated text: {repr(generated_text)}")
print(f"Token count: {token_count}")
@@ -357,21 +369,23 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
# will trigger early termination before min_tokens=50 is reached
# It's virtually impossible to generate 50 tokens without hitting
# at least one of: e, a, i, o, u, space, t, n, s, r
finish_reason = (outputs[0].outputs[0].finish_reason
if outputs[0].outputs else "unknown")
finish_reason = (
outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown"
)
print(f"Finish reason: {finish_reason}")
if finish_reason == "stop":
assert token_count >= 50, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=50. "
f"Reason: {finish_reason}. "
f"Text: {repr(generated_text)}")
assert token_count >= 50, (
"Bug confirmed: "
f"{token_count} tokens < min_tokens=50. "
f"Reason: {finish_reason}. "
f"Text: {repr(generated_text)}"
)
@pytest.mark.xfail(
reason=(
"Potential logits-processor bug: EOS tokens may bypass min_tokens"),
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
strict=False,
)
def test_min_tokens_eos_behavior(llm_v1: LLM):
@@ -404,8 +418,14 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
finish_no_min = choice_no_min.finish_reason
stop_no_min = choice_no_min.stop_reason
print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min,
" stop_reason=", stop_no_min)
print(
"[no-min] tokens=",
len(ids_no_min),
" finish=",
finish_no_min,
" stop_reason=",
stop_no_min,
)
assert finish_no_min == "stop", (
f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}"
@@ -414,7 +434,8 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
"For EOS-based stop (no user stop strings), stop_reason should be None."
)
assert len(ids_no_min) < max_toks, (
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}")
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}"
)
# Case 2: WITH min_tokens
sp_with_min = SamplingParams(
@@ -430,23 +451,31 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
finish_with_min = choice_with_min.finish_reason
stop_with_min = choice_with_min.stop_reason
print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min,
" stop_reason=", stop_with_min)
print(
"[with-min] tokens=",
len(ids_with_min),
" finish=",
finish_with_min,
" stop_reason=",
stop_with_min,
)
# Exact length reached; EOS should have been blocked
assert len(ids_with_min) == max_toks, (
f"Expected exactly {max_toks} tokens with min_tokens; "
f"got {len(ids_with_min)}")
f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}"
)
assert finish_with_min == "length", (
f"Expected finish_reason 'length'; got {finish_with_min}")
f"Expected finish_reason 'length'; got {finish_with_min}"
)
assert eos_token_id not in ids_with_min, (
"EOS token id should not appear when min_tokens prevents early EOS.")
"EOS token id should not appear when min_tokens prevents early EOS."
)
def test_min_tokens_validation():
"""
Test that SamplingParams correctly validates min_tokens parameters.
This tests the parameter validation logic in SamplingParams.
"""
# Valid cases
@@ -456,14 +485,14 @@ def test_min_tokens_validation():
# Invalid cases
with pytest.raises(
ValueError,
match="min_tokens must be greater than or equal to 0",
ValueError,
match="min_tokens must be greater than or equal to 0",
):
SamplingParams(min_tokens=-1, max_tokens=10)
with pytest.raises(
ValueError,
match="min_tokens must be less than or equal to max_tokens",
ValueError,
match="min_tokens must be less than or equal to max_tokens",
):
SamplingParams(min_tokens=15, max_tokens=10)