Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
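Most of the changes below are mechanical reformatting rather than behavioural edits: yapf aligns continuation lines with the opening bracket, whereas ruff's Black-style formatter either keeps a statement on one line or hang-indents one element per line with trailing delimiters. A minimal sketch of the two styles (hypothetical values, not code from this commit; the example URL is a placeholder):

    # yapf (before): nested literal opens on the same line, continuations aligned
    placeholders = [{
        "type": "image_url",
        "image_url": {
            "url": "https://example.com/stop_sign.jpg"
        },
    }]

    # ruff format (after): hanging indent, one element per line, trailing delimiters
    placeholders = [
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/stop_sign.jpg"},
        }
    ]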
@@ -48,19 +48,17 @@ def get_test_prompts(mm_enabled: bool):
             give no other output than that simple sentence without quotes.
             """
         elif kind == "mm":
-            placeholders = [{
-                "type": "image_url",
-                "image_url": {
-                    "url":
-                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
-                },
-            }]
+            placeholders = [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                    },
+                }
+            ]
             prompt = [
                 *placeholders,
-                {
-                    "type": "text",
-                    "text": "The meaning of the image is"
-                },
+                {"type": "text", "text": "The meaning of the image is"},
             ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
@@ -84,10 +82,10 @@ def test_ngram_correctness(
     sampling_config: SamplingParams,
     model_name: str,
 ):
-    '''
+    """
     Compare the outputs of an original LLM and a speculative LLM
     should be the same when using ngram speculative decoding.
-    '''
+    """
     test_prompts = get_test_prompts(mm_enabled=False)

     ref_llm = LLM(model=model_name, max_model_len=1024)
@@ -129,32 +127,77 @@ def test_ngram_correctness(
     ["model_setup", "mm_enabled"],
     [
         (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
-        pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
-                      "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
-                     False,
-                     marks=pytest.mark.skip(reason="Skipping due to its " \
-                        "head_dim not being a a multiple of 32")),
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-                     False,
-                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
-        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-                     True,
-                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
-        (("eagle", "eagle618/deepseek-v3-random",
-          "eagle618/eagle-deepseek-v3-random", 1), False),
+        pytest.param(
+            (
+                "eagle3",
+                "Qwen/Qwen2.5-VL-7B-Instruct",
+                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
+                1,
+            ),
+            False,
+            marks=pytest.mark.skip(
+                reason="Skipping due to its head_dim not being a a multiple of 32"
+            ),
+        ),
+        (
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+        ),
+        (
+            (
+                "eagle3",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                1,
+            ),
+            False,
+        ),
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
+            False,
+            marks=large_gpu_mark(min_gb=80),
+        ),  # works on 4x H100
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+                4,
+            ),
+            True,
+            marks=large_gpu_mark(min_gb=80),
+        ),  # works on 4x H100
+        (
+            (
+                "eagle",
+                "eagle618/deepseek-v3-random",
+                "eagle618/eagle-deepseek-v3-random",
+                1,
+            ),
+            False,
+        ),
     ],
     ids=[
-        "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
-        "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
-    ])
-@pytest.mark.parametrize("attn_backend",
-                         get_attn_backend_list_based_on_platform())
+        "qwen3_eagle3",
+        "qwen2_5_vl_eagle3",
+        "llama3_eagle",
+        "llama3_eagle3",
+        "llama4_eagle",
+        "llama4_eagle_mm",
+        "deepseek_eagle",
+    ],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -166,15 +209,16 @@ def test_eagle_correctness(
         # TODO: Fix this flaky test
         pytest.skip(
             "TREE_ATTN is flaky in the test disable for now until it can be "
-            "resolved (see https://github.com/vllm-project/vllm/issues/22922)")
+            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
+        )

     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
-    '''
+    """
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     model_setup: (method, model_name, eagle_model_name, tp_size)
-    '''
+    """
     with monkeypatch.context() as m:
         if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
             # Scout requires default backend selection
@@ -185,18 +229,20 @@ def test_eagle_correctness(
             m.setenv("VLLM_MLA_DISABLE", "1")
             m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

-        if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
-            pytest.skip("TRITON_ATTN does not support "
-                        "multi-token eagle spec decode on current platform")
+        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+            pytest.skip(
+                "TRITON_ATTN does not support "
+                "multi-token eagle spec decode on current platform"
+            )

         if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")

         method, model_name, spec_model_name, tp_size = model_setup

-        ref_llm = LLM(model=model_name,
-                      max_model_len=2048,
-                      tensor_parallel_size=tp_size)
+        ref_llm = LLM(
+            model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
+        )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
         torch.cuda.empty_cache()
@@ -233,11 +279,14 @@ def test_eagle_correctness(
     cleanup_dist_env_and_memory()


-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
-    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
-],
-                         ids=["mimo", "deepseek"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
+    ],
+    ids=["mimo", "deepseek"],
+)
 def test_mtp_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
@@ -246,21 +295,23 @@ def test_mtp_correctness(
 ):
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
-    '''
+    """
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using MTP speculative decoding.
     model_setup: (method, model_name, tp_size)
-    '''
+    """
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_MLA_DISABLE", "1")

         method, model_name, tp_size = model_setup

-        ref_llm = LLM(model=model_name,
-                      max_model_len=2048,
-                      tensor_parallel_size=tp_size,
-                      trust_remote_code=True)
+        ref_llm = LLM(
+            model=model_name,
+            max_model_len=2048,
+            tensor_parallel_size=tp_size,
+            trust_remote_code=True,
+        )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
         torch.cuda.empty_cache()