[BugFix] fix num_lookahead_slots missing in async executor (#4165)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
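Note that the hunks below cover only the test harness; per the commit title, the underlying bug was that num_lookahead_slots was not passed along on the async executor path, so speculative decoding there silently ran with the default of zero lookahead slots. A toy sketch of that failure mode (all names are illustrative stand-ins, not vLLM's actual executor API):

import asyncio

class ToyExecutor:
    # Stand-in for an executor with sync and async entry points.
    def execute_model(self, seqs, num_lookahead_slots: int = 0):
        return f"sync: lookahead={num_lookahead_slots}"

    async def execute_model_async(self, seqs, num_lookahead_slots: int = 0):
        return f"async: lookahead={num_lookahead_slots}"

executor = ToyExecutor()
# The sync path forwarded the scheduler's lookahead slots...
print(executor.execute_model([], num_lookahead_slots=4))
# ...but the async caller omitted the argument, so it silently
# fell back to the default of 0 lookahead slots:
print(asyncio.run(executor.execute_model_async([])))
# The fix is to forward the value on the async path as well:
print(asyncio.run(executor.execute_model_async([], num_lookahead_slots=4)))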
@@ -1,10 +1,127 @@
-from typing import List, Tuple
+import asyncio
+from typing import List, Optional, Tuple, Union
 
 import pytest
+import ray
 
 from tests.conftest import cleanup
 from vllm import LLM
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.lora.request import LoRARequest
 from vllm.model_executor.utils import set_random_seed
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter, random_uuid
+
+
+class AsyncLLM:
+    """AsyncLLM
+
+    Note: The current LLM class in vllm doesn't support async mode. For test
+    purposes we implement an async one here; it may move to
+    vllm/entrypoints/llm.py in the future.
+
+    The AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py,
+    with changes to make it work in async mode.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        tokenizer: Optional[str] = None,
+        tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
+        trust_remote_code: bool = False,
+        tensor_parallel_size: int = 1,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: int = 4,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: int = 8192,
+        disable_custom_all_reduce: bool = False,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        self.engine_args = AsyncEngineArgs(
+            model=model,
+            tokenizer=tokenizer,
+            tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
+            trust_remote_code=trust_remote_code,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            quantization=quantization,
+            revision=revision,
+            tokenizer_revision=tokenizer_revision,
+            seed=seed,
+            gpu_memory_utilization=gpu_memory_utilization,
+            swap_space=swap_space,
+            enforce_eager=enforce_eager,
+            max_context_len_to_capture=max_context_len_to_capture,
+            engine_use_ray=True,
+            disable_custom_all_reduce=disable_custom_all_reduce,
+            **kwargs,
+        )
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: Optional[Union[str, List[str]]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+
+        llm_engine = AsyncLLMEngine.from_engine_args(
+            self.engine_args, usage_context=UsageContext.LLM_CLASS)
+
+        if prompts is None:
+            raise ValueError("prompts must be provided.")
+        if isinstance(prompts, str):
+            # Convert a single prompt to a list.
+            prompts = [prompts]
+
+        num_requests = len(prompts)
+
+        if sampling_params is None:
+            # Use default sampling params.
+            sampling_params = SamplingParams()
+
+        elif isinstance(sampling_params,
+                        list) and len(sampling_params) != num_requests:
+            raise ValueError("The lengths of prompts and "
+                             "sampling_params must be the same.")
+
+        async def get_output(prompt, sampling_param) -> RequestOutput:
+            request_id = random_uuid()
+            results_generator = llm_engine.generate(prompt, sampling_param,
+                                                    request_id)
+            # Drain the stream; the last item carries the full output.
+            final_output = None
+            async for request_output in results_generator:
+                final_output = request_output
+            return final_output
+
+        outputs = []
+        try:
+            for i in range(num_requests):
+                prompt = prompts[i]
+                # Pick the per-request params when a list was provided.
+                sampling_param = (sampling_params[i] if isinstance(
+                    sampling_params, list) else sampling_params)
+                res = asyncio.run(get_output(prompt, sampling_param))
+                outputs.append(res)
+        finally:
+            ray.shutdown()
+        return outputs
 
 
 @pytest.fixture
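For reference, a minimal, hypothetical way to drive the helper defined above; the model name and sampling settings are placeholders, not taken from the commit:

# Hedged usage sketch of the AsyncLLM test helper from this diff; the model
# and sampling parameters are illustrative.
from vllm.sampling_params import SamplingParams

llm = AsyncLLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(
    prompts=["Hello, my name is"],
    sampling_params=SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)

Each generate() call builds a fresh AsyncLLMEngine and shuts Ray down in its finally block, so calls are isolated but pay full engine startup cost each time, an acceptable trade-off for tests.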
@@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
 
     def generator_inner():
         print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
-        llm = LLM(**kwargs)
+
+        use_async = False
+        if "use_async" in kwargs:
+            use_async = kwargs.pop("use_async")
+
+        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
         set_random_seed(seed)
 
         yield llm
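Downstream tests opt into the async path by putting "use_async": True in their kwargs, which generator_inner pops before choosing between AsyncLLM and LLM. A hedged sketch of such a test (the fixture and parametrize names follow this conftest's conventions, but the exact test shape and model are illustrative):

# Illustrative only: a test that opts into the AsyncLLM path.
import pytest
from vllm import SamplingParams

@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",  # placeholder; any small model works
    "use_async": True,              # popped by generator_inner -> AsyncLLM
}])
def test_generates_on_async_engine(baseline_llm_generator):
    # baseline_llm_generator is built by create_llm_generator above.
    for llm in baseline_llm_generator():
        outputs = llm.generate(["Hello, my name is"],
                               SamplingParams(max_tokens=8))
        assert outputs and outputs[0].outputs[0].text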