Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||
RequestOutput, PoolingRequestOutput], None]]:
|
||||
...
|
||||
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||
RequestOutput, PoolingRequestOutput], None]]:
|
||||
...
|
||||
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
return stream.generator()
|
||||
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
@@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
|
||||
for generation, if any.
|
||||
priority: The priority of the request.
|
||||
Only applicable with priority scheduling.
|
||||
|
||||
data_parallel_rank: The (global) data parallel rank that must
|
||||
handle this request. Only applicable if DP is enabled.
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine
|
||||
for the request.
|
||||
@@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
):
|
||||
yield LLMEngine.validate_output(output, RequestOutput)
|
||||
except asyncio.CancelledError:
|
||||
|
||||
Reference in New Issue
Block a user