[Meta] Official Eagle mm support, first enablement on llama4 (#20788)

Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
zhiweiz
2025-07-31 10:35:07 -07:00
committed by GitHub
parent 53c21e492e
commit 9e0726e5bf
8 changed files with 205 additions and 36 deletions

View File

@@ -3,29 +3,34 @@
from __future__ import annotations
import random
from typing import Any
from typing import Any, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: Union[str, list[dict[str, Any]]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@@ -57,7 +77,6 @@ def model_name():
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -103,23 +123,32 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
@pytest.mark.parametrize(
["model_setup", "mm_enabled"], [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.