[Meta] Official Eagle mm support, first enablement on llama4 (#20788)
Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
@@ -3,29 +3,34 @@
 from __future__ import annotations
 
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
+import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
     print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])
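The new "mm" prompt kind turns a conversation's content into OpenAI chat-completions-style parts (one image_url part plus one text part) instead of a plain string, which is why prompt is now annotated as Union[str, list[dict[str, Any]]]. A minimal sketch, not part of the diff, of the conversation shape appended for the mm case, reusing the asset constants imported in the hunk above:

from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR

# A single user turn whose content mixes an image_url part and a text part;
# LLM.chat accepts this alongside plain-string conversations in one batch.
mm_conversation = [{
    "role": "user",
    "content": [
        {
            "type": "image_url",
            "image_url": {
                "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
            },
        },
        {
            "type": "text",
            "text": "The meaning of the image is"
        },
    ],
}]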
@@ -57,7 +77,6 @@ def model_name():
 
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        test_prompts = get_test_prompts(mm_enabled=False)
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
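With the test_prompts fixture removed, test_ngram_correctness builds its prompts inline and always passes mm_enabled=False, since n-gram drafting is text-only. The hunk shows only the reference run; a sketch of how the comparison presumably continues below it (the ngram speculative_config keys follow vLLM's schema, but the concrete values and the 70% match threshold here are illustrative, not taken from this diff):

        del ref_llm
        cleanup_dist_env_and_memory()

        # Re-run the same prompts with n-gram speculative decoding enabled;
        # speculation must not change greedy outputs, so the generated texts
        # are compared pairwise against the reference run.
        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        matches = sum(
            1 for ref, spec in zip(ref_outputs, spec_outputs)
            if ref.outputs[0].text == spec.outputs[0].text)
        # Heuristic tolerance for occasional numerical divergence.
        assert matches >= int(0.7 * len(ref_outputs))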
@@ -103,23 +123,32 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"], [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            False,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            True,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
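In the widened parametrization, model_setup stays a (method, target model, draft model, tensor-parallel size) 4-tuple and mm_enabled only selects the prompt mix; both Llama-4 Scout variants remain skipped for CI memory reasons. A sketch of how the tuple is presumably unpacked in the body below this hunk; the speculative_config keys follow vLLM's schema, while the concrete values are illustrative:

    method, model_name, spec_model_name, tp_size = model_setup

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": method,            # "eagle" or "eagle3"
            "model": spec_model_name,    # EAGLE draft (speculator) weights
            "num_speculative_tokens": 3,
        },
        tensor_parallel_size=tp_size,    # 4 for Llama-4-Scout-17B-16E
        max_model_len=2048,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)

Once the OOM skips are lifted, the new multimodal path can be exercised directly with pytest -k llama4_eagle_mm.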