[Meta] Llama4 EAGLE Support (#20591)
Signed-off-by: qizixi <qizixi@meta.com> Co-authored-by: qizixi <qizixi@meta.com>
This commit is contained in:
@@ -6,8 +6,10 @@ import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,14 +55,6 @@ def model_name():
|
||||
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def eagle_model_name():
|
||||
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def eagle3_model_name():
|
||||
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
@@ -77,6 +71,8 @@ def test_ngram_correctness(
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
@@ -103,34 +99,50 @@ def test_ngram_correctness(
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("model_setup", [
|
||||
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
|
||||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
use_eagle3: bool,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048)
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_model_name = eagle3_model_name(
|
||||
) if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"method": method,
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": 2048,
|
||||
@@ -152,3 +164,5 @@ def test_eagle_correctness(
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
Reference in New Issue
Block a user