Add: SupportsEagle3 interface for explicit EAGLE3 support (#22642)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
@@ -3,12 +3,20 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import supports_eagle3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_path",
|
||||
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
|
||||
def test_llama(vllm_runner, example_prompts, model_path):
|
||||
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
|
||||
# Allow insecure (pickle-based) serialization for the V1 engine; presumably
# needed so the supports_eagle3 callable can be sent via apply_model() below.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
|
||||
eagle3_supported = vllm_model.apply_model(supports_eagle3)
|
||||
assert eagle3_supported
|
||||
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens=20)
|
||||
print(vllm_outputs)
|
||||
@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
|
||||
@pytest.mark.parametrize(
|
||||
"model_path",
|
||||
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
|
||||
def test_qwen(vllm_runner, example_prompts, model_path):
|
||||
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
|
||||
# Allow insecure (pickle-based) serialization for the V1 engine; presumably
# needed so the supports_eagle3 callable can be sent via apply_model() below.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
|
||||
eagle3_supported = vllm_model.apply_model(supports_eagle3)
|
||||
assert eagle3_supported
|
||||
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens=20)
|
||||
print(vllm_outputs)
|
||||
|
||||
Reference in New Issue
Block a user