Allow AsyncLLMEngine.generate to target a specific data-parallel (DP) rank (#19102)

Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
jmswen
2025-06-04 08:26:47 -07:00
committed by GitHub
parent 8f4ffbd373
commit c8dcc15921
10 changed files with 97 additions and 5 deletions

View File

@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...
@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...
@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return stream.generator()
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
handle this request. Only applicable if data parallelism (DP) is enabled.
Yields:
The output `RequestOutput` objects from the LLMEngine
for the request.
@@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError: