EAGLE Support DP>1 (#26086)
Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: Rémi Delacourt <54138269+Flechman@users.noreply.github.com>
Signed-off-by: remi <remi@mistral.ai>
tests/v1/distributed/test_eagle_dp.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from contextlib import AsyncExitStack
from dataclasses import replace

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

DP_SIZE = int(os.getenv("DP_SIZE", 2))


@pytest.mark.asyncio
async def test_run_eagle_dp():
    target_model = "meta-llama/Llama-3.1-8B-Instruct"
    draft_model = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

    engine_args = AsyncEngineArgs(
        model=target_model,
        tokenizer_mode="auto",
        enforce_eager=False,
        tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",  # ray takes more time
        trust_remote_code=True,
        max_model_len=16384,
    )

    eagle_engine_args = replace(
        engine_args,
        speculative_config={
            "model": draft_model,
            "method": "eagle",
            "num_speculative_tokens": 3,
        },
    )

    prompt = "This is a test of data parallel with eagle"
    num_expected_tokens = 100
    sampling_params = SamplingParams(
        min_tokens=num_expected_tokens,
        max_tokens=num_expected_tokens,
        ignore_eos=True,
        output_kind=RequestOutputKind.FINAL_ONLY,
        temperature=0,
    )

    async def generate_with_timeout(given_engine: AsyncLLM):
        async for out in given_engine.generate(
            request_id="test-eagle-dp", prompt=prompt, sampling_params=sampling_params
        ):
            token_ids = out.outputs[0].token_ids
            assert len(token_ids) == num_expected_tokens
            return token_ids

    async def engine_create_and_generate(engine_args: AsyncEngineArgs):
        async with AsyncExitStack() as after:
            engine = AsyncLLM.from_engine_args(engine_args)
            after.callback(engine.shutdown)

            token_ids = await asyncio.wait_for(
                generate_with_timeout(engine), timeout=30
            )

            assert not engine.output_processor.has_unfinished_requests()
            return token_ids

    token_ids_with_eagle = await engine_create_and_generate(eagle_engine_args)
    token_ids_no_eagle = await engine_create_and_generate(engine_args)

    # Test for correctness
    assert token_ids_with_eagle == token_ids_no_eagle
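The correctness check relies on temperature=0: decoding is greedy, and speculative decoding with EAGLE is expected to be lossless, so the EAGLE run and the plain run should produce identical token IDs. To run the test locally (a sketch, assuming a host with enough GPUs for the chosen parallel sizes and pytest-asyncio installed; DP_SIZE and TP_SIZE are read from the environment and default to 2 and 1):

DP_SIZE=2 TP_SIZE=1 pytest tests/v1/distributed/test_eagle_dp.py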