[Frontend] Refactor prompt processing (#4028)
Co-authored-by: Roger Wang <ywang@roblox.com>
vllm/engine/async_llm_engine.py
@@ -1,8 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+                    Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
 
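The typing change above recurs throughout this diff: trace_headers parameters move from Optional[Dict[str, str]] to Optional[Mapping[str, str]]. Mapping promises read access only, so callers may pass any dict-like object (a plain dict, a MappingProxyType, an HTTP header container) without a type error. A minimal sketch of the difference under a static checker; the function names here are illustrative, not part of this commit:

from types import MappingProxyType
from typing import Dict, Mapping, Optional


def takes_dict(trace_headers: Optional[Dict[str, str]] = None) -> None:
    # Dict promises in-place mutation, so static checkers reject
    # read-only mappings such as MappingProxyType here.
    pass


def takes_mapping(trace_headers: Optional[Mapping[str, str]] = None) -> None:
    # Mapping promises read access only, so any dict-like object is
    # accepted and the signature documents that headers stay untouched.
    pass


headers = MappingProxyType({"traceparent": "00-abc123-def456-01"})
takes_mapping(headers)  # fine under mypy/pyright
takes_dict(headers)     # flagged by mypy/pyright, though it runs at runtime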
@@ -151,7 +151,10 @@ class RequestTracker:
             logger.info("Finished request %s.", request_id)
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""
@@ -166,6 +169,9 @@ class RequestTracker:
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info("Added request %s.", request_id)
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
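The two RequestTracker hunks above route per-request logging through a keyword-only verbose flag, so the call site (wired up later via verbose=self.log_requests) decides whether "Added request" lines are emitted. A stripped-down stand-in showing the pattern, not the actual vLLM class:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MiniTracker:
    """Stripped-down stand-in for RequestTracker (illustration only)."""

    def __init__(self) -> None:
        self._requests: dict = {}

    def add_request(self, request_id: str, *, verbose: bool = False,
                    **engine_add_request_kwargs) -> dict:
        # The bare * makes verbose keyword-only, so a positional argument
        # can never bind to it by accident.
        self._requests[request_id] = engine_add_request_kwargs
        if verbose:
            logger.info("Added request %s.", request_id)
        return self._requests[request_id]


tracker = MiniTracker()
tracker.add_request("req-0", verbose=True, prompt="Hello")  # logs one line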
@@ -299,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine):
         return self.input_processor(llm_inputs)
 
     async def add_request_async(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -353,8 +359,6 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for :class:`LLMEngine`.
@@ -368,13 +372,11 @@ class AsyncLLMEngine:
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
         self.background_loop: Optional[asyncio.Future] = None
@@ -468,7 +470,6 @@ class AsyncLLMEngine:
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
             stat_loggers=stat_loggers,
@@ -667,30 +668,9 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncStream:
-        if self.log_requests:
-            if isinstance(inputs, str):
-                shortened_prompt = inputs
-                shortened_token_ids = None
-            else:
-                shortened_prompt = inputs.get("prompt")
-                shortened_token_ids = inputs.get("prompt_token_ids")
-
-            max_log_len = self.max_log_len
-            if max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:max_log_len]
-
-            logger.info(
-                "Received request %s: prompt: %r, "
-                "params: %s, prompt_token_ids: %s, "
-                "lora_request: %s.", request_id, shortened_prompt, params,
-                shortened_token_ids, lora_request)
-
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -706,6 +686,7 @@ class AsyncLLMEngine:
 
         stream = self._request_tracker.add_request(
             request_id,
+            verbose=self.log_requests,
             inputs=inputs,
             params=params,
             arrival_time=arrival_time,
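With max_log_len and the "Received request" block deleted in the hunks above, the only log left on this path is the tracker's one-liner, enabled here by verbose=self.log_requests. Code that still wants length-capped prompt logging can clamp before calling logger.info; a self-contained sketch of the removed slice-to-limit idiom, with the helper name invented for illustration:

from typing import List, Optional, Tuple


def shorten_for_log(
    prompt: Optional[str],
    prompt_token_ids: Optional[List[int]],
    max_log_len: Optional[int],
) -> Tuple[Optional[str], Optional[List[int]]]:
    # seq[:None] returns the whole sequence, so the clamp only applies
    # when a limit is actually configured.
    return (prompt[:max_log_len] if prompt is not None else None,
            prompt_token_ids[:max_log_len]
            if prompt_token_ids is not None else None)


print(shorten_for_log("a very long prompt", [1, 2, 3, 4], 6))
# -> ('a very', [1, 2, 3, 4])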
@@ -721,7 +702,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -804,7 +785,7 @@ class AsyncLLMEngine:
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
 
@@ -882,7 +863,7 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
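End-to-end, the public entry points keep their shapes; only the trace_headers annotation loosens to Mapping. A hedged usage sketch against this version of the API; the model name, prompt, request id, and header value are placeholders:

import asyncio
from types import MappingProxyType

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # Any read-only mapping now satisfies the trace_headers annotation.
    # Assumption: tracing is configured on this engine; if not, leave
    # trace_headers at its default of None.
    headers = MappingProxyType({"traceparent": "00-abc123-def456-01"})

    final = None
    # generate() is an async generator yielding incremental RequestOutputs.
    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="demo-0",
                                        trace_headers=headers):
        final = output
    print(final.outputs[0].text)


asyncio.run(main())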