[V1][Spec Decode] Eagle Model loading (#16035)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
Lily Liu
2025-04-10 11:21:48 -07:00
committed by GitHub
parent 9665313c39
commit e8224f3dca
9 changed files with 251 additions and 28 deletions

View File

@@ -53,6 +53,11 @@ def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"
@pytest.fixture
def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -95,3 +100,47 @@ def test_ngram_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
eagle_model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "eagle",
"model": eagle_model_name,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm