Add Eagle and Eagle3 support to Transformers modeling backend (#30340)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
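
The diff below wires a model_impl dimension into the existing Eagle/Eagle3 speculative-decoding tests so the same cases also run against the Transformers modeling backend. As a rough sketch of what the new "transformers" case exercises, the following shows how Eagle3 speculation might be enabled together with that backend through vLLM's offline LLM API; the model names come from the test parametrization in the diff, while the speculative_config keys and the num_speculative_tokens value are assumptions to verify against your vLLM version, not part of this commit.

# Sketch only (not part of this commit): Eagle3 speculation on the Transformers
# modeling backend. Model names are taken from the test matrix below; the
# speculative_config keys and the num_speculative_tokens value are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-8B",                     # target model used by the new test case
    speculative_config={
        "method": "eagle3",                    # Eagle3 draft-head speculation
        "model": "AngelSlim/Qwen3-8B_eagle3",  # Eagle3 draft weights from the test matrix
        "num_speculative_tokens": 3,           # illustrative value, not from the diff
    },
    model_impl="transformers",                 # force the Transformers modeling backend
    max_model_len=2048,
)

outputs = llm.chat(
    [[{"role": "user", "content": "Summarize speculative decoding in one sentence."}]],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)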
@@ -280,9 +280,20 @@ def test_speculators_model_integration(
 @pytest.mark.parametrize(
-    ["model_setup", "mm_enabled", "enable_chunked_prefill"],
+    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
     [
-        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "auto",
+        ),
+        (
+            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
+            False,
+            False,
+            "transformers",
+        ),
         pytest.param(
             (
                 "eagle3",
@@ -292,6 +303,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
             ),
@@ -305,6 +317,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=pytest.mark.skip(
                 reason="Skipping due to its head_dim not being a multiple of 32"
             ),
@@ -318,6 +331,7 @@ def test_speculators_model_integration(
             ),
             False,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=40),
         ), # works on 4x H100
         (
@@ -329,6 +343,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
         pytest.param(
             (
@@ -339,6 +354,7 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ), # works on 4x H100
         pytest.param(
@@ -350,6 +366,7 @@ def test_speculators_model_integration(
             ),
             True,
             True,
+            "auto",
             marks=large_gpu_mark(min_gb=80),
         ), # works on 4x H100
         (
@@ -361,10 +378,12 @@ def test_speculators_model_integration(
             ),
             False,
             False,
+            "auto",
         ),
     ],
     ids=[
         "qwen3_eagle3",
+        "qwen3_eagle3-transformers",
         "qwen3_vl_eagle3",
         "qwen2_5_vl_eagle3",
         "llama3_eagle",
@@ -381,6 +400,7 @@ def test_eagle_correctness(
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
     enable_chunked_prefill: bool,
+    model_impl: str,
     attn_backend: str,
 ):
     if attn_backend == "TREE_ATTN":
@@ -389,6 +409,17 @@ def test_eagle_correctness(
             "TREE_ATTN is flaky in the test; disable for now until it can be "
             "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
         )
+    if model_impl == "transformers":
+        import transformers
+        from packaging.version import Version
+
+        installed = Version(transformers.__version__)
+        required = Version("5.0.0.dev")
+        if installed < required:
+            pytest.skip(
+                "Eagle3 with the Transformers modeling backend requires "
+                f"transformers>={required}, but got {installed}"
+            )
 
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -448,6 +479,7 @@ def test_eagle_correctness(
         max_model_len=max_model_len,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        model_impl=model_impl,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
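
A possible way to run only the new backend-specific case, written as a Python runner since the surrounding tests are Python: the test-file path is an assumption (the diff does not name it), while the id "qwen3_eagle3-transformers" is the one added to the ids list above.

# Hypothetical runner: the test-file path below is assumed, not taken from this diff;
# only the id "qwen3_eagle3-transformers" is introduced by the change above.
import sys

import pytest

exit_code = pytest.main(
    [
        "tests/v1/e2e/test_spec_decode.py",  # assumed location of these tests
        "-k", "qwen3_eagle3-transformers",   # select the new Transformers-backend case
        "-q",
    ]
)
sys.exit(exit_code)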