[V1][Spec Decode] EAGLE-3 Support (#16937)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Benjamin Chislett
2025-04-25 18:43:07 -04:00
committed by GitHub
parent 70116459c3
commit a0e619e62a
12 changed files with 358 additions and 34 deletions

View File

@@ -52,8 +52,8 @@ def main():
args = parse_args()
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
max_model_len = 2048
@@ -81,7 +81,7 @@ def main():
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config={
"method": "eagle",
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,
@@ -95,6 +95,9 @@ def main():
outputs = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)
if not hasattr(outputs, "metrics") or outputs.metrics is None:
return
# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
@@ -109,6 +112,11 @@ def main():
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
print("-" * 50)
# print acceptance at each token position
for i in range(len(acceptance_counts)):
print(f"acceptance at token {i}:"
f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}")
if __name__ == "__main__":
main()