[Frontend] Refactor prompt processing (#4028)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-07-23 01:13:53 +08:00
committed by GitHub
parent 89c1c6a196
commit 739b61a348
24 changed files with 699 additions and 391 deletions

View File

@@ -1,8 +1,8 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@@ -151,7 +151,10 @@ class RequestTracker:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def add_request(self, request_id: str,
def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
@@ -166,6 +169,9 @@ class RequestTracker:
self.new_requests_event.set()
if verbose:
logger.info("Added request %s.", request_id)
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -299,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine):
return self.input_processor(llm_inputs)
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -353,8 +359,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`.
@@ -368,13 +372,11 @@ class AsyncLLMEngine:
engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop: Optional[asyncio.Future] = None
@@ -468,7 +470,6 @@ class AsyncLLMEngine:
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
@@ -667,30 +668,9 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@@ -706,6 +686,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
@@ -721,7 +702,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -804,7 +785,7 @@ class AsyncLLMEngine:
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
@@ -882,7 +863,7 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or