Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,8 +14,10 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
|
||||
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
|
||||
|
||||
if attn_backend == "FLASHINFER":
|
||||
pytest.skip("This test is failing with FlashInfer backend and "
|
||||
"needs investigation. See issue #25679.")
|
||||
pytest.skip(
|
||||
"This test is failing with FlashInfer backend and "
|
||||
"needs investigation. See issue #25679."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -92,7 +92,7 @@ def test_max_context_length(
|
||||
)
|
||||
|
||||
# HF returns the prompt + generated tokens. Slice off the prompt.
|
||||
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
|
||||
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :]
|
||||
|
||||
# check that exactly max_tokens tokens were generated with vLLM and HF
|
||||
assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
|
||||
|
||||
@@ -26,12 +26,14 @@ model_config = {
|
||||
[
|
||||
"bigcode/starcoder2-3b", # sliding window only
|
||||
"google/gemma-3-1b-it", # sliding window + full attention
|
||||
])
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
|
||||
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
|
||||
disable_hybrid_kv_cache_manager):
|
||||
def test_sliding_window_retrieval(
|
||||
monkeypatch, model, batch_size, seed, disable_hybrid_kv_cache_manager
|
||||
):
|
||||
"""
|
||||
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
|
||||
asks for value of one of them (which is outside the sliding window).
|
||||
@@ -44,33 +46,38 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
|
||||
test_config = model_config[model]
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
|
||||
model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager
|
||||
)
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
||||
|
||||
prompts, answer, indices = prep_prompts(batch_size,
|
||||
ln_range=test_config.ln_range)
|
||||
prompts, answer, indices = prep_prompts(
|
||||
batch_size, ln_range=test_config.ln_range
|
||||
)
|
||||
|
||||
check_length(prompts, llm, test_config.sliding_window)
|
||||
|
||||
# Fresh generation
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0)
|
||||
check_answers(
|
||||
indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0,
|
||||
)
|
||||
|
||||
# Re-generate with the same prompts to test prefix caching
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0)
|
||||
check_answers(
|
||||
indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0,
|
||||
)
|
||||
|
||||
|
||||
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
|
||||
"""
|
||||
Check if the prompt length is valid, i.e., longer than the sliding window
|
||||
Check if the prompt length is valid, i.e., longer than the sliding window
|
||||
size and shorter than the model's max length.
|
||||
|
||||
Args:
|
||||
@@ -80,9 +87,9 @@ def check_length(prompts: list[str], llm: LLM, sliding_window: int):
|
||||
"""
|
||||
tokenizer = llm.get_tokenizer()
|
||||
max_model_len = llm.llm_engine.model_config.max_model_len
|
||||
assert any(
|
||||
len(tokenizer.encode(prompt)) > sliding_window
|
||||
for prompt in prompts), "Prompt is too short for test"
|
||||
assert all(
|
||||
len(tokenizer.encode(prompt)) <= max_model_len
|
||||
for prompt in prompts), "Prompt is too long for test"
|
||||
assert any(len(tokenizer.encode(prompt)) > sliding_window for prompt in prompts), (
|
||||
"Prompt is too short for test"
|
||||
)
|
||||
assert all(len(tokenizer.encode(prompt)) <= max_model_len for prompt in prompts), (
|
||||
"Prompt is too long for test"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,9 @@ def test_kv_sharing_fast_prefill(
|
||||
# managing buffers for cudagraph
|
||||
cudagraph_copy_inputs=True,
|
||||
level=CompilationLevel.PIECEWISE
|
||||
if not enforce_eager else CompilationLevel.NO_COMPILATION)
|
||||
if not enforce_eager
|
||||
else CompilationLevel.NO_COMPILATION,
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
@@ -94,21 +96,21 @@ def test_kv_sharing_fast_prefill(
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
|
||||
llm = LLM(model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
kv_sharing_fast_prefill=True)
|
||||
llm = LLM(
|
||||
model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
kv_sharing_fast_prefill=True,
|
||||
)
|
||||
optimized_responses = llm.generate(test_prompts, sampling_params)
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
|
||||
misses = 0
|
||||
|
||||
for ref_response, optimized_response in zip(ref_responses,
|
||||
optimized_responses):
|
||||
if ref_response.outputs[0].text != optimized_response.outputs[
|
||||
0].text:
|
||||
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
|
||||
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
|
||||
misses += 1
|
||||
|
||||
assert misses == 0
|
||||
|
||||
@@ -46,29 +46,36 @@ class MinTokensTestCase:
|
||||
self.expected_exact_len = expected_exact_len
|
||||
|
||||
def __str__(self):
|
||||
return (f"{self.name}: min={self.min_tokens}, "
|
||||
f"max={self.max_tokens}, stop={self.stop}")
|
||||
return (
|
||||
f"{self.name}: min={self.min_tokens}, "
|
||||
f"max={self.max_tokens}, stop={self.stop}"
|
||||
)
|
||||
|
||||
|
||||
# Test scenarios covering all critical cases
|
||||
MIN_TOKENS_TEST_CASES = [
|
||||
# === BASIC FUNCTIONALITY (should work) ===
|
||||
MinTokensTestCase(name="basic_min_tokens_no_stop",
|
||||
min_tokens=8,
|
||||
max_tokens=20,
|
||||
stop=None,
|
||||
expected_min_len=8),
|
||||
MinTokensTestCase(name="min_tokens_zero",
|
||||
min_tokens=0,
|
||||
max_tokens=10,
|
||||
stop=None,
|
||||
expected_min_len=0),
|
||||
MinTokensTestCase(name="min_equals_max_no_stop",
|
||||
min_tokens=15,
|
||||
max_tokens=15,
|
||||
stop=None,
|
||||
expected_exact_len=15),
|
||||
|
||||
MinTokensTestCase(
|
||||
name="basic_min_tokens_no_stop",
|
||||
min_tokens=8,
|
||||
max_tokens=20,
|
||||
stop=None,
|
||||
expected_min_len=8,
|
||||
),
|
||||
MinTokensTestCase(
|
||||
name="min_tokens_zero",
|
||||
min_tokens=0,
|
||||
max_tokens=10,
|
||||
stop=None,
|
||||
expected_min_len=0,
|
||||
),
|
||||
MinTokensTestCase(
|
||||
name="min_equals_max_no_stop",
|
||||
min_tokens=15,
|
||||
max_tokens=15,
|
||||
stop=None,
|
||||
expected_exact_len=15,
|
||||
),
|
||||
# === STOP STRINGS WITH MIN_TOKENS ===
|
||||
# These tests expose the detokenizer bug where stop strings
|
||||
# bypass min_tokens
|
||||
@@ -94,9 +101,11 @@ MIN_TOKENS_TEST_CASES = [
|
||||
expected_min_len=5,
|
||||
),
|
||||
marks=pytest.mark.xfail(
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens "
|
||||
"(fixed by PR #22014)"),
|
||||
strict=False),
|
||||
reason=(
|
||||
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
|
||||
),
|
||||
strict=False,
|
||||
),
|
||||
id="min_tokens_with_comprehensive_stops",
|
||||
),
|
||||
pytest.param(
|
||||
@@ -108,12 +117,13 @@ MIN_TOKENS_TEST_CASES = [
|
||||
expected_min_len=3,
|
||||
),
|
||||
marks=pytest.mark.xfail(
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens "
|
||||
"(fixed by PR #22014)"),
|
||||
strict=False),
|
||||
reason=(
|
||||
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
|
||||
),
|
||||
strict=False,
|
||||
),
|
||||
id="min_tokens_with_simple_char_stop",
|
||||
),
|
||||
|
||||
# === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
|
||||
# These test the MinTokensLogitsProcessor handling of EOS tokens
|
||||
pytest.param(
|
||||
@@ -125,26 +135,26 @@ MIN_TOKENS_TEST_CASES = [
|
||||
expected_exact_len=20,
|
||||
),
|
||||
marks=pytest.mark.xfail(
|
||||
reason=
|
||||
("Potential logits-processor bug: EOS tokens may bypass min_tokens"
|
||||
),
|
||||
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
|
||||
strict=False,
|
||||
),
|
||||
id="min_equals_max_eos_only",
|
||||
),
|
||||
|
||||
# === EDGE CASES ===
|
||||
MinTokensTestCase(name="large_min_tokens",
|
||||
min_tokens=50,
|
||||
max_tokens=60,
|
||||
stop=None,
|
||||
expected_min_len=50),
|
||||
MinTokensTestCase(
|
||||
name="large_min_tokens",
|
||||
min_tokens=50,
|
||||
max_tokens=60,
|
||||
stop=None,
|
||||
expected_min_len=50,
|
||||
),
|
||||
MinTokensTestCase(
|
||||
name="min_tokens_with_empty_stop_list",
|
||||
min_tokens=5,
|
||||
max_tokens=15,
|
||||
stop=[], # Empty stop list
|
||||
expected_min_len=5),
|
||||
expected_min_len=5,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -170,25 +180,27 @@ def get_token_count(output: RequestOutput) -> int:
|
||||
return len(output.outputs[0].token_ids)
|
||||
|
||||
|
||||
def assert_min_tokens_satisfied(output: RequestOutput,
|
||||
test_case: MinTokensTestCase) -> None:
|
||||
def assert_min_tokens_satisfied(
|
||||
output: RequestOutput, test_case: MinTokensTestCase
|
||||
) -> None:
|
||||
"""Assert that min_tokens requirement is satisfied"""
|
||||
token_count = get_token_count(output)
|
||||
stop_reason = (output.outputs[0].stop_reason
|
||||
if output.outputs else "no output")
|
||||
stop_reason = output.outputs[0].stop_reason if output.outputs else "no output"
|
||||
|
||||
if test_case.expected_exact_len is not None:
|
||||
# Exact length requirement
|
||||
assert token_count == test_case.expected_exact_len, (
|
||||
f"Expected exactly {test_case.expected_exact_len} tokens, "
|
||||
f"got {token_count} tokens. "
|
||||
f"Stop reason: {stop_reason}")
|
||||
f"Stop reason: {stop_reason}"
|
||||
)
|
||||
else:
|
||||
# Minimum length requirement
|
||||
assert token_count >= (test_case.expected_min_len or 0), (
|
||||
f"Expected at least {test_case.expected_min_len} tokens, "
|
||||
f"got {token_count} tokens. "
|
||||
f"Stop reason: {stop_reason}")
|
||||
f"Stop reason: {stop_reason}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -199,13 +211,13 @@ def assert_min_tokens_satisfied(output: RequestOutput,
|
||||
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
|
||||
"""
|
||||
Comprehensive test for min_tokens functionality in V1 engine.
|
||||
|
||||
|
||||
This test covers all critical scenarios for min_tokens:
|
||||
- Basic functionality (should work)
|
||||
- Stop strings with min_tokens (known bug)
|
||||
- EOS tokens with min_tokens (potential bug)
|
||||
- Edge cases
|
||||
|
||||
|
||||
Args:
|
||||
llm_v1: V1 LLM instance
|
||||
test_case: Test scenario parameters
|
||||
@@ -218,7 +230,7 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
|
||||
max_tokens=test_case.max_tokens,
|
||||
stop=test_case.stop,
|
||||
temperature=GREEDY,
|
||||
include_stop_str_in_output=True # Include stop strings for debugging
|
||||
include_stop_str_in_output=True, # Include stop strings for debugging
|
||||
)
|
||||
|
||||
# Use simple prompt. Comprehensive stop lists should catch any generation
|
||||
@@ -250,13 +262,11 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
|
||||
def test_min_tokens_basic_functionality(llm_v1: LLM):
|
||||
"""
|
||||
Test basic min_tokens functionality without stop conditions.
|
||||
|
||||
|
||||
This is a baseline test that should always pass and validates
|
||||
that min_tokens works correctly in the simple case.
|
||||
"""
|
||||
sampling_params = SamplingParams(min_tokens=10,
|
||||
max_tokens=20,
|
||||
temperature=GREEDY)
|
||||
sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY)
|
||||
|
||||
prompt = "Once upon a time"
|
||||
outputs = llm_v1.generate([prompt], sampling_params)
|
||||
@@ -269,17 +279,16 @@ def test_min_tokens_basic_functionality(llm_v1: LLM):
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens "
|
||||
"(fixed by PR #22014)"),
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
|
||||
strict=False,
|
||||
)
|
||||
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
|
||||
"""
|
||||
Test the specific bug where stop strings bypass min_tokens.
|
||||
|
||||
|
||||
This test specifically reproduces the bug Calvin is fixing in PR #22014.
|
||||
It should fail until that fix is merged.
|
||||
|
||||
|
||||
Strategy: Use guaranteed stop characters that will appear
|
||||
in any generated text.
|
||||
"""
|
||||
@@ -291,7 +300,8 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
|
||||
# Common letter; likely appears early
|
||||
stop=["e"],
|
||||
temperature=GREEDY,
|
||||
include_stop_str_in_output=True)
|
||||
include_stop_str_in_output=True,
|
||||
)
|
||||
|
||||
# Simple prompt that will generate text containing "e"
|
||||
prompt = "The quick brown fox"
|
||||
@@ -308,23 +318,25 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
|
||||
|
||||
# This assertion should fail due to the bug - if stop string is found early,
|
||||
# the model should still continue generating until min_tokens is reached
|
||||
stop_reason = (outputs[0].outputs[0].stop_reason
|
||||
if outputs[0].outputs else "no output")
|
||||
assert token_count >= 15, ("Bug confirmed: "
|
||||
f"{token_count} tokens < min_tokens=15. "
|
||||
f"Reason: {stop_reason}. "
|
||||
f"Text: {repr(generated_text)}")
|
||||
stop_reason = (
|
||||
outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output"
|
||||
)
|
||||
assert token_count >= 15, (
|
||||
"Bug confirmed: "
|
||||
f"{token_count} tokens < min_tokens=15. "
|
||||
f"Reason: {stop_reason}. "
|
||||
f"Text: {repr(generated_text)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens "
|
||||
"(fixed by PR #22014)"),
|
||||
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
|
||||
strict=False,
|
||||
)
|
||||
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
|
||||
"""
|
||||
Guaranteed test for stop strings bypassing min_tokens bug.
|
||||
|
||||
|
||||
Strategy: Use very low temperature and multiple common stop strings
|
||||
to virtually guarantee early detection, combined with long min_tokens
|
||||
to ensure the bug is exposed regardless of model behavior.
|
||||
@@ -337,7 +349,8 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
|
||||
# Use multiple very common patterns - at least one will appear
|
||||
stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
|
||||
temperature=GREEDY,
|
||||
include_stop_str_in_output=True)
|
||||
include_stop_str_in_output=True,
|
||||
)
|
||||
|
||||
# Simple prompt that will generate some text
|
||||
prompt = "The cat"
|
||||
@@ -346,8 +359,7 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
|
||||
assert len(outputs) == 1
|
||||
token_count = get_token_count(outputs[0])
|
||||
generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
|
||||
stop_reason = (outputs[0].outputs[0].stop_reason
|
||||
if outputs[0].outputs else "unknown")
|
||||
stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown"
|
||||
|
||||
print(f"Generated text: {repr(generated_text)}")
|
||||
print(f"Token count: {token_count}")
|
||||
@@ -357,21 +369,23 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
|
||||
# will trigger early termination before min_tokens=50 is reached
|
||||
# It's virtually impossible to generate 50 tokens without hitting
|
||||
# at least one of: e, a, i, o, u, space, t, n, s, r
|
||||
finish_reason = (outputs[0].outputs[0].finish_reason
|
||||
if outputs[0].outputs else "unknown")
|
||||
finish_reason = (
|
||||
outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown"
|
||||
)
|
||||
|
||||
print(f"Finish reason: {finish_reason}")
|
||||
|
||||
if finish_reason == "stop":
|
||||
assert token_count >= 50, ("Bug confirmed: "
|
||||
f"{token_count} tokens < min_tokens=50. "
|
||||
f"Reason: {finish_reason}. "
|
||||
f"Text: {repr(generated_text)}")
|
||||
assert token_count >= 50, (
|
||||
"Bug confirmed: "
|
||||
f"{token_count} tokens < min_tokens=50. "
|
||||
f"Reason: {finish_reason}. "
|
||||
f"Text: {repr(generated_text)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"Potential logits-processor bug: EOS tokens may bypass min_tokens"),
|
||||
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
|
||||
strict=False,
|
||||
)
|
||||
def test_min_tokens_eos_behavior(llm_v1: LLM):
|
||||
@@ -404,8 +418,14 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
|
||||
finish_no_min = choice_no_min.finish_reason
|
||||
stop_no_min = choice_no_min.stop_reason
|
||||
|
||||
print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min,
|
||||
" stop_reason=", stop_no_min)
|
||||
print(
|
||||
"[no-min] tokens=",
|
||||
len(ids_no_min),
|
||||
" finish=",
|
||||
finish_no_min,
|
||||
" stop_reason=",
|
||||
stop_no_min,
|
||||
)
|
||||
|
||||
assert finish_no_min == "stop", (
|
||||
f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}"
|
||||
@@ -414,7 +434,8 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
|
||||
"For EOS-based stop (no user stop strings), stop_reason should be None."
|
||||
)
|
||||
assert len(ids_no_min) < max_toks, (
|
||||
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}")
|
||||
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}"
|
||||
)
|
||||
|
||||
# Case 2: WITH min_tokens
|
||||
sp_with_min = SamplingParams(
|
||||
@@ -430,23 +451,31 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
|
||||
finish_with_min = choice_with_min.finish_reason
|
||||
stop_with_min = choice_with_min.stop_reason
|
||||
|
||||
print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min,
|
||||
" stop_reason=", stop_with_min)
|
||||
print(
|
||||
"[with-min] tokens=",
|
||||
len(ids_with_min),
|
||||
" finish=",
|
||||
finish_with_min,
|
||||
" stop_reason=",
|
||||
stop_with_min,
|
||||
)
|
||||
|
||||
# Exact length reached; EOS should have been blocked
|
||||
assert len(ids_with_min) == max_toks, (
|
||||
f"Expected exactly {max_toks} tokens with min_tokens; "
|
||||
f"got {len(ids_with_min)}")
|
||||
f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}"
|
||||
)
|
||||
assert finish_with_min == "length", (
|
||||
f"Expected finish_reason 'length'; got {finish_with_min}")
|
||||
f"Expected finish_reason 'length'; got {finish_with_min}"
|
||||
)
|
||||
assert eos_token_id not in ids_with_min, (
|
||||
"EOS token id should not appear when min_tokens prevents early EOS.")
|
||||
"EOS token id should not appear when min_tokens prevents early EOS."
|
||||
)
|
||||
|
||||
|
||||
def test_min_tokens_validation():
|
||||
"""
|
||||
Test that SamplingParams correctly validates min_tokens parameters.
|
||||
|
||||
|
||||
This tests the parameter validation logic in SamplingParams.
|
||||
"""
|
||||
# Valid cases
|
||||
@@ -456,14 +485,14 @@ def test_min_tokens_validation():
|
||||
|
||||
# Invalid cases
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="min_tokens must be greater than or equal to 0",
|
||||
ValueError,
|
||||
match="min_tokens must be greater than or equal to 0",
|
||||
):
|
||||
SamplingParams(min_tokens=-1, max_tokens=10)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="min_tokens must be less than or equal to max_tokens",
|
||||
ValueError,
|
||||
match="min_tokens must be less than or equal to max_tokens",
|
||||
):
|
||||
SamplingParams(min_tokens=15, max_tokens=10)
|
||||
|
||||
|
||||
@@ -48,19 +48,17 @@ def get_test_prompts(mm_enabled: bool):
|
||||
give no other output than that simple sentence without quotes.
|
||||
"""
|
||||
elif kind == "mm":
|
||||
placeholders = [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url":
|
||||
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||
},
|
||||
}]
|
||||
placeholders = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||
},
|
||||
}
|
||||
]
|
||||
prompt = [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The meaning of the image is"
|
||||
},
|
||||
{"type": "text", "text": "The meaning of the image is"},
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt type: {kind}")
|
||||
@@ -84,10 +82,10 @@ def test_ngram_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
):
|
||||
'''
|
||||
"""
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
"""
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
@@ -129,32 +127,77 @@ def test_ngram_correctness(
|
||||
["model_setup", "mm_enabled"],
|
||||
[
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to its " \
|
||||
"head_dim not being a a multiple of 32")),
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
(("eagle", "eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle3",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"Rayzl/qwen2.5-vl-7b-eagle3-sgl",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
marks=pytest.mark.skip(
|
||||
reason="Skipping due to its head_dim not being a a multiple of 32"
|
||||
),
|
||||
),
|
||||
(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
),
|
||||
(
|
||||
(
|
||||
"eagle3",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
4,
|
||||
),
|
||||
False,
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
pytest.param(
|
||||
(
|
||||
"eagle",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
|
||||
4,
|
||||
),
|
||||
True,
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
(
|
||||
(
|
||||
"eagle",
|
||||
"eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random",
|
||||
1,
|
||||
),
|
||||
False,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
|
||||
])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
"qwen3_eagle3",
|
||||
"qwen2_5_vl_eagle3",
|
||||
"llama3_eagle",
|
||||
"llama3_eagle3",
|
||||
"llama4_eagle",
|
||||
"llama4_eagle_mm",
|
||||
"deepseek_eagle",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
@@ -166,15 +209,16 @@ def test_eagle_correctness(
|
||||
# TODO: Fix this flaky test
|
||||
pytest.skip(
|
||||
"TREE_ATTN is flaky in the test disable for now until it can be "
|
||||
"resolved (see https://github.com/vllm-project/vllm/issues/22922)")
|
||||
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
|
||||
)
|
||||
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
'''
|
||||
"""
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
'''
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
# Scout requires default backend selection
|
||||
@@ -185,18 +229,20 @@ def test_eagle_correctness(
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"TRITON_ATTN does not support "
|
||||
"multi-token eagle spec decode on current platform"
|
||||
)
|
||||
|
||||
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size)
|
||||
ref_llm = LLM(
|
||||
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
@@ -233,11 +279,14 @@ def test_eagle_correctness(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
|
||||
],
|
||||
ids=["mimo", "deepseek"])
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled"],
|
||||
[
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
|
||||
],
|
||||
ids=["mimo", "deepseek"],
|
||||
)
|
||||
def test_mtp_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
@@ -246,21 +295,23 @@ def test_mtp_correctness(
|
||||
):
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
'''
|
||||
"""
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using MTP speculative decoding.
|
||||
model_setup: (method, model_name, tp_size)
|
||||
'''
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
method, model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True)
|
||||
ref_llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user