[V1][Spec Decode] EAGLE-3 Support (#16937)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
committed by
GitHub
parent
70116459c3
commit
a0e619e62a
@@ -50,12 +50,15 @@ def sampling_config():
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def eagle_model_name():
|
||||
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
|
||||
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def eagle3_model_name():
|
||||
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
@@ -102,12 +105,13 @@ def test_ngram_correctness(
|
||||
del spec_llm
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
eagle_model_name: str,
|
||||
use_eagle3: bool,
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
@@ -116,18 +120,22 @@ def test_eagle_correctness(
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_model_name = eagle3_model_name(
|
||||
) if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
speculative_config={
|
||||
"method": "eagle",
|
||||
"model": eagle_model_name,
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": 2048,
|
||||
},
|
||||
max_model_len=1024,
|
||||
max_model_len=2048,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
|
||||
Reference in New Issue
Block a user