[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import cdiv, kill_process_tree
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
|
||||
@@ -170,7 +171,7 @@ class AsyncLLM(EngineClient):
|
||||
# requests we don't need to send multiple messages to core proc,
|
||||
# and so we don't need multiple streams which then get
|
||||
# re-multiplexed in the API server anyhow.
|
||||
async def generate(
|
||||
async def _generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
@@ -241,6 +242,30 @@ class AsyncLLM(EngineClient):
|
||||
await self.abort(request_id)
|
||||
raise
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
kwargs = dict(prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority)
|
||||
if sampling_params.n is None or sampling_params.n == 1:
|
||||
return self._generate(**kwargs)
|
||||
else:
|
||||
# Special handling for parallel sampling requests
|
||||
return generate_parallel_sampling_async(generate=self._generate,
|
||||
**kwargs)
|
||||
|
||||
async def _run_output_handler(self):
|
||||
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user