Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -14,8 +14,10 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER":
pytest.skip("This test is failing with FlashInfer backend and "
"needs investigation. See issue #25679.")
pytest.skip(
"This test is failing with FlashInfer backend and "
"needs investigation. See issue #25679."
)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

View File

@@ -92,7 +92,7 @@ def test_max_context_length(
)
# HF returns the prompt + generated tokens. Slice off the prompt.
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]
hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :]
# check that exactly max_tokens tokens were generated with vLLM and HF
assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens

View File

@@ -26,12 +26,14 @@ model_config = {
[
"bigcode/starcoder2-3b", # sliding window only
"google/gemma-3-1b-it", # sliding window + full attention
])
],
)
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
disable_hybrid_kv_cache_manager):
def test_sliding_window_retrieval(
monkeypatch, model, batch_size, seed, disable_hybrid_kv_cache_manager
):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for value of one of them (which is outside the sliding window).
@@ -44,33 +46,38 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
test_config = model_config[model]
llm = LLM(
model=model,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
prompts, answer, indices = prep_prompts(batch_size,
ln_range=test_config.ln_range)
prompts, answer, indices = prep_prompts(
batch_size, ln_range=test_config.ln_range
)
check_length(prompts, llm, test_config.sliding_window)
# Fresh generation
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
check_answers(
indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0,
)
# Re-generate with the same prompts to test prefix caching
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
check_answers(
indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0,
)
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
"""
Check if the prompt length is valid, i.e., longer than the sliding window
Check if the prompt length is valid, i.e., longer than the sliding window
size and shorter than the model's max length.
Args:
@@ -80,9 +87,9 @@ def check_length(prompts: list[str], llm: LLM, sliding_window: int):
"""
tokenizer = llm.get_tokenizer()
max_model_len = llm.llm_engine.model_config.max_model_len
assert any(
len(tokenizer.encode(prompt)) > sliding_window
for prompt in prompts), "Prompt is too short for test"
assert all(
len(tokenizer.encode(prompt)) <= max_model_len
for prompt in prompts), "Prompt is too long for test"
assert any(len(tokenizer.encode(prompt)) > sliding_window for prompt in prompts), (
"Prompt is too short for test"
)
assert all(len(tokenizer.encode(prompt)) <= max_model_len for prompt in prompts), (
"Prompt is too long for test"
)

View File

@@ -76,7 +76,9 @@ def test_kv_sharing_fast_prefill(
# managing buffers for cudagraph
cudagraph_copy_inputs=True,
level=CompilationLevel.PIECEWISE
if not enforce_eager else CompilationLevel.NO_COMPILATION)
if not enforce_eager
else CompilationLevel.NO_COMPILATION,
)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -94,21 +96,21 @@ def test_kv_sharing_fast_prefill(
cleanup(llm, compilation_config)
llm = LLM(model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
kv_sharing_fast_prefill=True)
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
kv_sharing_fast_prefill=True,
)
optimized_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
misses = 0
for ref_response, optimized_response in zip(ref_responses,
optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[
0].text:
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
misses += 1
assert misses == 0

View File

@@ -46,29 +46,36 @@ class MinTokensTestCase:
self.expected_exact_len = expected_exact_len
def __str__(self):
return (f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}")
return (
f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}"
)
# Test scenarios covering all critical cases
MIN_TOKENS_TEST_CASES = [
# === BASIC FUNCTIONALITY (should work) ===
MinTokensTestCase(name="basic_min_tokens_no_stop",
min_tokens=8,
max_tokens=20,
stop=None,
expected_min_len=8),
MinTokensTestCase(name="min_tokens_zero",
min_tokens=0,
max_tokens=10,
stop=None,
expected_min_len=0),
MinTokensTestCase(name="min_equals_max_no_stop",
min_tokens=15,
max_tokens=15,
stop=None,
expected_exact_len=15),
MinTokensTestCase(
name="basic_min_tokens_no_stop",
min_tokens=8,
max_tokens=20,
stop=None,
expected_min_len=8,
),
MinTokensTestCase(
name="min_tokens_zero",
min_tokens=0,
max_tokens=10,
stop=None,
expected_min_len=0,
),
MinTokensTestCase(
name="min_equals_max_no_stop",
min_tokens=15,
max_tokens=15,
stop=None,
expected_exact_len=15,
),
# === STOP STRINGS WITH MIN_TOKENS ===
# These tests expose the detokenizer bug where stop strings
# bypass min_tokens
@@ -94,9 +101,11 @@ MIN_TOKENS_TEST_CASES = [
expected_min_len=5,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
reason=(
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
),
strict=False,
),
id="min_tokens_with_comprehensive_stops",
),
pytest.param(
@@ -108,12 +117,13 @@ MIN_TOKENS_TEST_CASES = [
expected_min_len=3,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
reason=(
"Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
),
strict=False,
),
id="min_tokens_with_simple_char_stop",
),
# === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
# These test the MinTokensLogitsProcessor handling of EOS tokens
pytest.param(
@@ -125,26 +135,26 @@ MIN_TOKENS_TEST_CASES = [
expected_exact_len=20,
),
marks=pytest.mark.xfail(
reason=
("Potential logits-processor bug: EOS tokens may bypass min_tokens"
),
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
strict=False,
),
id="min_equals_max_eos_only",
),
# === EDGE CASES ===
MinTokensTestCase(name="large_min_tokens",
min_tokens=50,
max_tokens=60,
stop=None,
expected_min_len=50),
MinTokensTestCase(
name="large_min_tokens",
min_tokens=50,
max_tokens=60,
stop=None,
expected_min_len=50,
),
MinTokensTestCase(
name="min_tokens_with_empty_stop_list",
min_tokens=5,
max_tokens=15,
stop=[], # Empty stop list
expected_min_len=5),
expected_min_len=5,
),
]
@@ -170,25 +180,27 @@ def get_token_count(output: RequestOutput) -> int:
return len(output.outputs[0].token_ids)
def assert_min_tokens_satisfied(output: RequestOutput,
test_case: MinTokensTestCase) -> None:
def assert_min_tokens_satisfied(
output: RequestOutput, test_case: MinTokensTestCase
) -> None:
"""Assert that min_tokens requirement is satisfied"""
token_count = get_token_count(output)
stop_reason = (output.outputs[0].stop_reason
if output.outputs else "no output")
stop_reason = output.outputs[0].stop_reason if output.outputs else "no output"
if test_case.expected_exact_len is not None:
# Exact length requirement
assert token_count == test_case.expected_exact_len, (
f"Expected exactly {test_case.expected_exact_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
f"Stop reason: {stop_reason}"
)
else:
# Minimum length requirement
assert token_count >= (test_case.expected_min_len or 0), (
f"Expected at least {test_case.expected_min_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
f"Stop reason: {stop_reason}"
)
@pytest.mark.parametrize(
@@ -199,13 +211,13 @@ def assert_min_tokens_satisfied(output: RequestOutput,
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
"""
Comprehensive test for min_tokens functionality in V1 engine.
This test covers all critical scenarios for min_tokens:
- Basic functionality (should work)
- Stop strings with min_tokens (known bug)
- EOS tokens with min_tokens (potential bug)
- Edge cases
Args:
llm_v1: V1 LLM instance
test_case: Test scenario parameters
@@ -218,7 +230,7 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
max_tokens=test_case.max_tokens,
stop=test_case.stop,
temperature=GREEDY,
include_stop_str_in_output=True # Include stop strings for debugging
include_stop_str_in_output=True, # Include stop strings for debugging
)
# Use simple prompt. Comprehensive stop lists should catch any generation
@@ -250,13 +262,11 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
def test_min_tokens_basic_functionality(llm_v1: LLM):
"""
Test basic min_tokens functionality without stop conditions.
This is a baseline test that should always pass and validates
that min_tokens works correctly in the simple case.
"""
sampling_params = SamplingParams(min_tokens=10,
max_tokens=20,
temperature=GREEDY)
sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY)
prompt = "Once upon a time"
outputs = llm_v1.generate([prompt], sampling_params)
@@ -269,17 +279,16 @@ def test_min_tokens_basic_functionality(llm_v1: LLM):
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
"""
Test the specific bug where stop strings bypass min_tokens.
This test specifically reproduces the bug Calvin is fixing in PR #22014.
It should fail until that fix is merged.
Strategy: Use guaranteed stop characters that will appear
in any generated text.
"""
@@ -291,7 +300,8 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
# Common letter; likely appears early
stop=["e"],
temperature=GREEDY,
include_stop_str_in_output=True)
include_stop_str_in_output=True,
)
# Simple prompt that will generate text containing "e"
prompt = "The quick brown fox"
@@ -308,23 +318,25 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM):
# This assertion should fail due to the bug - if stop string is found early,
# the model should still continue generating until min_tokens is reached
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "no output")
assert token_count >= 15, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=15. "
f"Reason: {stop_reason}. "
f"Text: {repr(generated_text)}")
stop_reason = (
outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output"
)
assert token_count >= 15, (
"Bug confirmed: "
f"{token_count} tokens < min_tokens=15. "
f"Reason: {stop_reason}. "
f"Text: {repr(generated_text)}"
)
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
"""
Guaranteed test for stop strings bypassing min_tokens bug.
Strategy: Use very low temperature and multiple common stop strings
to virtually guarantee early detection, combined with long min_tokens
to ensure the bug is exposed regardless of model behavior.
@@ -337,7 +349,8 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
# Use multiple very common patterns - at least one will appear
stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
temperature=GREEDY,
include_stop_str_in_output=True)
include_stop_str_in_output=True,
)
# Simple prompt that will generate some text
prompt = "The cat"
@@ -346,8 +359,7 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
assert len(outputs) == 1
token_count = get_token_count(outputs[0])
generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "unknown")
stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown"
print(f"Generated text: {repr(generated_text)}")
print(f"Token count: {token_count}")
@@ -357,21 +369,23 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
# will trigger early termination before min_tokens=50 is reached
# It's virtually impossible to generate 50 tokens without hitting
# at least one of: e, a, i, o, u, space, t, n, s, r
finish_reason = (outputs[0].outputs[0].finish_reason
if outputs[0].outputs else "unknown")
finish_reason = (
outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown"
)
print(f"Finish reason: {finish_reason}")
if finish_reason == "stop":
assert token_count >= 50, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=50. "
f"Reason: {finish_reason}. "
f"Text: {repr(generated_text)}")
assert token_count >= 50, (
"Bug confirmed: "
f"{token_count} tokens < min_tokens=50. "
f"Reason: {finish_reason}. "
f"Text: {repr(generated_text)}"
)
@pytest.mark.xfail(
reason=(
"Potential logits-processor bug: EOS tokens may bypass min_tokens"),
reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
strict=False,
)
def test_min_tokens_eos_behavior(llm_v1: LLM):
@@ -404,8 +418,14 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
finish_no_min = choice_no_min.finish_reason
stop_no_min = choice_no_min.stop_reason
print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min,
" stop_reason=", stop_no_min)
print(
"[no-min] tokens=",
len(ids_no_min),
" finish=",
finish_no_min,
" stop_reason=",
stop_no_min,
)
assert finish_no_min == "stop", (
f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}"
@@ -414,7 +434,8 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
"For EOS-based stop (no user stop strings), stop_reason should be None."
)
assert len(ids_no_min) < max_toks, (
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}")
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}"
)
# Case 2: WITH min_tokens
sp_with_min = SamplingParams(
@@ -430,23 +451,31 @@ def test_min_tokens_eos_behavior(llm_v1: LLM):
finish_with_min = choice_with_min.finish_reason
stop_with_min = choice_with_min.stop_reason
print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min,
" stop_reason=", stop_with_min)
print(
"[with-min] tokens=",
len(ids_with_min),
" finish=",
finish_with_min,
" stop_reason=",
stop_with_min,
)
# Exact length reached; EOS should have been blocked
assert len(ids_with_min) == max_toks, (
f"Expected exactly {max_toks} tokens with min_tokens; "
f"got {len(ids_with_min)}")
f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}"
)
assert finish_with_min == "length", (
f"Expected finish_reason 'length'; got {finish_with_min}")
f"Expected finish_reason 'length'; got {finish_with_min}"
)
assert eos_token_id not in ids_with_min, (
"EOS token id should not appear when min_tokens prevents early EOS.")
"EOS token id should not appear when min_tokens prevents early EOS."
)
def test_min_tokens_validation():
"""
Test that SamplingParams correctly validates min_tokens parameters.
This tests the parameter validation logic in SamplingParams.
"""
# Valid cases
@@ -456,14 +485,14 @@ def test_min_tokens_validation():
# Invalid cases
with pytest.raises(
ValueError,
match="min_tokens must be greater than or equal to 0",
ValueError,
match="min_tokens must be greater than or equal to 0",
):
SamplingParams(min_tokens=-1, max_tokens=10)
with pytest.raises(
ValueError,
match="min_tokens must be less than or equal to max_tokens",
ValueError,
match="min_tokens must be less than or equal to max_tokens",
):
SamplingParams(min_tokens=15, max_tokens=10)

View File

@@ -48,19 +48,17 @@ def get_test_prompts(mm_enabled: bool):
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
placeholders = [
{
"type": "image_url",
"image_url": {
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}
]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
{"type": "text", "text": "The meaning of the image is"},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
@@ -84,10 +82,10 @@ def test_ngram_correctness(
sampling_config: SamplingParams,
model_name: str,
):
'''
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
"""
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
@@ -129,32 +127,77 @@ def test_ngram_correctness(
["model_setup", "mm_enabled"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
False,
marks=pytest.mark.skip(reason="Skipping due to its " \
"head_dim not being a multiple of 32")),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
pytest.param(
(
"eagle3",
"Qwen/Qwen2.5-VL-7B-Instruct",
"Rayzl/qwen2.5-vl-7b-eagle3-sgl",
1,
),
False,
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a multiple of 32"
),
),
(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
1,
),
False,
),
(
(
"eagle3",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
1,
),
False,
),
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
),
False,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
(
"eagle",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
4,
),
True,
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
(
"eagle",
"eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random",
1,
),
False,
),
],
ids=[
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
"qwen3_eagle3",
"qwen2_5_vl_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm",
"deepseek_eagle",
],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
@@ -166,15 +209,16 @@ def test_eagle_correctness(
# TODO: Fix this flaky test
pytest.skip(
"TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)")
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
)
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
'''
"""
with monkeypatch.context() as m:
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
@@ -185,18 +229,20 @@ def test_eagle_correctness(
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size)
ref_llm = LLM(
model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
@@ -233,11 +279,14 @@ def test_eagle_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"])
@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
@@ -246,21 +295,23 @@ def test_mtp_correctness(
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
"""
Compare the outputs of an original LLM and a speculative LLM
should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_llm = LLM(
model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True,
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()