[Core] Async scheduling + structured outputs compatibility (#26866)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-10-31 17:35:04 -07:00
Committer: GitHub
Parent: df334868ca
Commit: 0cdbe7b744

25 changed files with 419 additions and 191 deletions


@@ -16,7 +16,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.utils.import_utils import resolve_obj_by_qualname
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.engine import ReconfigureDistributedRequest
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -187,28 +187,44 @@ class Executor(ABC):
     @overload
     def execute_model(
-        self,
-        scheduler_output: SchedulerOutput,
-        non_block: Literal[False] = False,
-    ) -> ModelRunnerOutput:
+        self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
+    ) -> ModelRunnerOutput | None:
         pass
 
     @overload
     def execute_model(
-        self,
-        scheduler_output: SchedulerOutput,
-        non_block: Literal[True] = True,
-    ) -> Future[ModelRunnerOutput]:
+        self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
+    ) -> Future[ModelRunnerOutput | None]:
         pass
 
     def execute_model(
         self, scheduler_output: SchedulerOutput, non_block: bool = False
-    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
+    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
         output = self.collective_rpc(  # type: ignore[call-overload]
             "execute_model", args=(scheduler_output,), non_block=non_block
         )
         return output[0]
 
+    @overload
+    def sample_tokens(
+        self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
+    ) -> ModelRunnerOutput:
+        pass
+
+    @overload
+    def sample_tokens(
+        self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
+    ) -> Future[ModelRunnerOutput]:
+        pass
+
+    def sample_tokens(
+        self, grammar_output: GrammarOutput | None, non_block: bool = False
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
+        output = self.collective_rpc(  # type: ignore[call-overload]
+            "sample_tokens", args=(grammar_output,), non_block=non_block
+        )
+        return output[0]
+
     def execute_dummy_batch(self) -> None:
         self.collective_rpc("execute_dummy_batch")

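Illustration (not part of the commit): a minimal sketch of how one engine step might drive the split execute_model()/sample_tokens() executor API above. The scheduler object and its get_grammar_bitmask()/update_from_output() helpers are hypothetical stand-ins, not the exact vLLM call sites.

from concurrent.futures import Future

def run_one_engine_step(scheduler, executor) -> None:
    scheduler_output = scheduler.schedule()

    # Phase 1: launch the forward pass. Sampling is deferred, so with
    # non_block=True this yields a Future whose result may simply be None.
    pending = executor.execute_model(scheduler_output, non_block=True)
    assert isinstance(pending, Future)

    # Phase 2: once structured-output grammars are ready, pass the bitmask
    # (or None when no request uses structured outputs) and sample tokens.
    grammar_output = scheduler.get_grammar_bitmask(scheduler_output)
    sampled = executor.sample_tokens(grammar_output, non_block=True)

    scheduler.update_from_output(scheduler_output, sampled.result())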

@@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
     get_mp_context,
     set_process_title,
 )
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.executor.abstract import Executor, FailureCallback
 from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerWrapperBase
@@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
                 uw.death_writer.close()
             self._ensure_worker_termination([uw.proc for uw in unready_workers])
 
-        # For pipeline parallel, we use a thread pool for asynchronous
-        # execute_model.
-        if self.max_concurrent_batches > 1:
-            # Note: must use only 1 IO thread to keep dequeue sequence
-            # from the response queue
-            # _async_aggregate_workers_output also assumes a single IO thread
-            self.io_thread_pool = ThreadPoolExecutor(
-                max_workers=1, thread_name_prefix="mp_exec_io"
-            )
+        # Note: must use only 1 IO thread to keep dequeue sequence
+        # from the response queue.
+        # _async_aggregate_workers_output also assumes a single IO thread.
+        self.io_thread_pool = ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="mp_exec_io"
+        )
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
         self.failure_callback = callback
 
     def execute_model(  # type: ignore[override]
-        self,
-        scheduler_output: SchedulerOutput,
-        non_block: bool = False,
+        self, scheduler_output: SchedulerOutput, non_block: bool = False
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
+        return self._execute_with_aggregation(
+            "execute_model", scheduler_output, non_block=non_block
+        )
+
+    def sample_tokens(  # type: ignore[override]
+        self, grammar_output: GrammarOutput | None, non_block: bool = False
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
+        return self._execute_with_aggregation(  # type: ignore[return-value]
+            "sample_tokens", grammar_output, non_block=non_block
+        )
+
+    def _execute_with_aggregation(
+        self, method: str, *args, non_block: bool = False
+    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
         if not self.has_connector:
             # get output only from a single worker (output_rank)
             (output,) = self.collective_rpc(
-                "execute_model",
-                args=(scheduler_output,),
+                method,
+                args=args,
                 unique_reply_rank=self.output_rank,
                 non_block=non_block,
                 timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
@@ -197,8 +206,8 @@
         # get output from all workers
         outputs = self.collective_rpc(
-            "execute_model",
-            args=(scheduler_output,),
+            method,
+            args=args,
             non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
         )

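Illustration (not from the diff): the shape of the fan-out that _execute_with_aggregation performs, sketched with simple in-process worker objects rather than vLLM's message-queue RPC; the single IO thread mirrors the io_thread_pool note above, and the merge step is elided.

from concurrent.futures import ThreadPoolExecutor
from typing import Any

_io_pool = ThreadPoolExecutor(max_workers=1)  # one IO thread preserves reply order

def dispatch(workers: list, method: str, *args: Any,
             has_connector: bool = False, output_rank: int = 0,
             non_block: bool = False):
    def call():
        results = [getattr(w, method)(*args) for w in workers]
        # Without a KV connector only the designated output rank's reply is
        # needed; with one, every rank's reply would be aggregated.
        return results if has_connector else results[output_rank]

    return _io_pool.submit(call) if non_block else call()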

@@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
     get_ip,
     get_open_port,
 )
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.executor.ray_utils import (
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
+COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
+COMPLETED_NONE_FUTURE.set_result(None)
+
 
 @dataclass
 class RayWorkerMetaData:
@@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
 
+        self.scheduler_output: SchedulerOutput | None = None
+
     @property
     def max_concurrent_batches(self) -> int:
         """Ray distributed executor supports pipeline parallelism,
@@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
         self.shutdown()
 
     def execute_model(  # type: ignore[override]
-        self, scheduler_output: SchedulerOutput, non_block: bool = False
+        self,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
+        if self.scheduler_output is not None:
+            raise RuntimeError(
+                "State error: sample_tokens() must be called "
+                "after execute_model() returns None."
+            )
+        self.scheduler_output = scheduler_output
+        return COMPLETED_NONE_FUTURE if non_block else None
+
+    def sample_tokens(  # type: ignore[override]
+        self,
+        grammar_output: "GrammarOutput | None",
+        non_block: bool = False,
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
         """Execute the model on the Ray workers.
 
+        The scheduler output to use should have been provided in
+        a prior call to execute_model().
+
         Args:
-            scheduler_output: The scheduler output to execute.
+            grammar_output: The structured outputs grammar bitmask, if applicable.
             non_block: If True, the method will return a Future.
 
         Returns:
             The model runner output.
         """
+        scheduler_output = self.scheduler_output
+        if scheduler_output is None:
+            return None  # noqa
+        self.scheduler_output = None
+
         # Build the compiled DAG for the first time.
         if self.forward_dag is None:  # type: ignore
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
 
-        refs = self.forward_dag.execute(scheduler_output)  # type: ignore
+        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore
 
         if not self.has_connector:
             # Get output only from a single worker (output_rank)

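Illustration (not part of the commit): a condensed sketch of the two-phase handshake the Ray executor now uses, where execute_model() only stashes the scheduler output and sample_tokens() later drives the compiled DAG; run_dag is a hypothetical stand-in for executing the DAG.

class TwoPhaseExecutorSketch:
    def __init__(self, run_dag):
        self._run_dag = run_dag  # stand-in for executing the compiled Ray DAG
        self._pending = None     # scheduler output stashed between the two phases

    def execute_model(self, scheduler_output):
        if self._pending is not None:
            raise RuntimeError("sample_tokens() must run before the next execute_model()")
        self._pending = scheduler_output
        return None  # real work is deferred to sample_tokens()

    def sample_tokens(self, grammar_output):
        scheduler_output, self._pending = self._pending, None
        if scheduler_output is None:
            return None  # nothing was scheduled this step
        return self._run_dag((scheduler_output, grammar_output))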

@@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerWrapperBase
 
 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
     from vllm.v1.outputs import ModelRunnerOutput
 
 logger = init_logger(__name__)
@@ -82,36 +82,41 @@ try:
         def execute_model_ray(
             self,
-            scheduler_output: Union[
-                "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
-            ],
+            execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
+            | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
         ) -> Union[
-            "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
+            "ModelRunnerOutput",
+            tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
         ]:
             # This method is used by Ray Compiled Graph to execute the model,
             # and it needs a special logic of self.setup_device_if_necessary()
             self.setup_device_if_necessary()
             assert self.worker is not None, "Worker is not initialized"
-            if isinstance(scheduler_output, tuple):
-                scheduler_output, intermediate_tensors = scheduler_output
+            if len(execute_model_input) == 3:
+                scheduler_output, grammar_output, intermediate_tensors = (
+                    execute_model_input
+                )
             else:
-                scheduler_output, intermediate_tensors = scheduler_output, None
+                scheduler_output, grammar_output = execute_model_input
+                intermediate_tensors = None
             assert self.worker.model_runner is not None
             output = self.worker.model_runner.execute_model(
                 scheduler_output, intermediate_tensors
             )
             if isinstance(output, IntermediateTensors):
-                output = scheduler_output, output
+                output = scheduler_output, grammar_output, output
             elif not get_pp_group().is_last_rank:
                 # Case where there are no scheduled requests
                 # but may still be finished requests.
                 assert not output or not output.req_ids
-                output = scheduler_output, None
-
-            # Ensure outputs crossing Ray compiled DAG are serializable.
-            # AsyncModelRunnerOutput holds CUDA events and cannot be
-            # pickled.
-            if isinstance(output, AsyncModelRunnerOutput):
-                output = output.get_output()
+                output = scheduler_output, grammar_output, None
+            elif output is None:
+                output = self.worker.model_runner.sample_tokens(grammar_output)
+
+                # Ensure outputs crossing Ray compiled DAG are serializable.
+                # AsyncModelRunnerOutput holds CUDA events and cannot be
+                # pickled.
+                if isinstance(output, AsyncModelRunnerOutput):
+                    output = output.get_output()
 
             return output
 
         def override_env_vars(self, vars: dict[str, str]):
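Illustration (not from the diff): roughly how the new (scheduler_output, grammar_output[, intermediate_tensors]) payload moves through a pipeline-parallel stage; the stage object and its attributes are hypothetical stand-ins for the worker's model runner.

def run_stage(stage, payload):
    # The first stage receives a 2-tuple; later stages also receive the
    # intermediate tensors produced upstream.
    if len(payload) == 3:
        scheduler_output, grammar_output, intermediate_tensors = payload
    else:
        scheduler_output, grammar_output = payload
        intermediate_tensors = None

    out = stage.execute_model(scheduler_output, intermediate_tensors)
    if stage.is_last_rank:
        # On the last rank execute_model defers sampling, so a None result
        # means the grammar bitmask can now be applied while sampling tokens.
        return stage.sample_tokens(grammar_output) if out is None else out

    # Non-last ranks forward their activations plus the payload metadata.
    return scheduler_output, grammar_output, out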