From 3e440786afe763e892e12125ee7529f95f141c54 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:30:32 -0500 Subject: [PATCH] [Feature] Fully support for async scheduling + PP, 30.8% E2E throughput improvement, 31.8% TPOT improvement (#32618) Signed-off-by: yewentao256 Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Nick Hill Co-authored-by: Nick Hill --- tests/test_config.py | 16 ++++++ tests/v1/core/test_scheduler.py | 15 ++++++ vllm/v1/core/sched/scheduler.py | 12 ++--- vllm/v1/worker/gpu_model_runner.py | 83 ++++++++++++++++++++++++------ 4 files changed, 103 insertions(+), 23 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 8c1bf6c40..1676598b1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -42,6 +42,22 @@ def test_compile_config_repr_succeeds(): assert "inductor_passes" in val +def test_async_scheduling_with_pipeline_parallelism_is_allowed(): + cfg = VllmConfig( + scheduler_config=SchedulerConfig( + max_model_len=8192, + is_encoder_decoder=False, + async_scheduling=True, + ), + parallel_config=ParallelConfig( + pipeline_parallel_size=2, + distributed_executor_backend="mp", + nnodes=2, + ), + ) + assert cfg.scheduler_config.async_scheduling is True + + @dataclass class _TestConfigFields: a: int diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index acac3753d..d8e9e2e3c 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -127,6 +127,21 @@ def test_schedule_multimodal_requests(): assert len(encoder_input) == 1 +def test_async_scheduling_pp_allows_rescheduling_with_output_placeholders(): + """Async scheduling + PP: allow multi-step in-flight scheduling per request""" + scheduler = create_scheduler(async_scheduling=True, pipeline_parallel_size=2) + (req,) = create_requests(num_requests=1, num_tokens=8) + scheduler.add_request(req) + + _ = scheduler.schedule() + assert req.num_output_placeholders > 0 + + # before any update_from_output, we still expect the request can be + # scheduled again (multi-step in-flight). + output = scheduler.schedule() + assert req.request_id in output.num_scheduled_tokens + + def test_schedule_partial_requests(): """Test scheduling behavior with partial requests. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 30a459386..7f1459fdf 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -344,13 +344,6 @@ class Scheduler(SchedulerInterface): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - # do not schedule another step for the same request while it still has - # output placeholders for PP. - # TODO: support PP + async scheduling without this limit - if self.use_pp and request.num_output_placeholders > 0: - req_index += 1 - continue - if ( request.num_output_placeholders > 0 # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). @@ -1003,7 +996,10 @@ class Scheduler(SchedulerInterface): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): req_id = req.request_id req_ids.append(req_id) - if self.use_pp: + # NOTE: In PP+async scheduling, we consume token ids via a direct GPU + # broadcast path (`input_batch.prev_sampled_token_ids`), so we can + # omit this payload. + if self.use_pp and not self.scheduler_config.async_scheduling: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 60c8d4080..96dab077d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1010,20 +1010,26 @@ class GPUModelRunner( req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: - # When using PP, the scheduler sends the sampled tokens back, - # because there's no direct communication between the first- - # stage worker and the last-stage worker. - new_token_ids = req_data.new_token_ids[i] - # Add the sampled token(s) from the previous step (if any). - # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = ( - num_computed_tokens + len(new_token_ids) - req_state.num_tokens - ) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + if not req_data.new_token_ids: + # Async scheduled PP: Sampled tokens propagated via GPU broadcast. + new_token_ids: list[int] = [] + else: + # Non-async scheduling with PP: The scheduler sends + # sampled token ids back because there's no direct communication + # between the first-stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:] + ) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -3577,7 +3583,9 @@ class GPUModelRunner( self.kv_connector_output = None if self.execute_model_state is None: - # Nothing to do (PP non-final rank case), output isn't used. + # receive sampled token ids from the last PP rank. + if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # type: ignore[return-value] @@ -3618,6 +3626,12 @@ class GPUModelRunner( self._update_states_after_model_execute( sampler_output.sampled_token_ids, scheduler_output ) + if self.use_async_scheduling: + pp = get_pp_group() + if pp.world_size > 1 and pp.is_last_rank: + self._pp_broadcast_prev_sampled_token_ids( + sampler_output.sampled_token_ids + ) self._draft_token_ids = None self._draft_token_req_ids = None @@ -3753,6 +3767,45 @@ class GPUModelRunner( return async_output + def _pp_broadcast_prev_sampled_token_ids( + self, sampled_token_ids: torch.Tensor + ) -> None: + """Broadcast sampled token ids (GPU) from last PP stage""" + pp = get_pp_group() + assert pp.is_last_rank + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. + assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, ( + "PP+async expects sampled_token_ids to have shape [num_reqs, 1]" + ) + torch.distributed.broadcast( + sampled_token_ids, src=pp.rank, group=pp.device_group + ) + + def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None: + """Receive sampled token ids broadcast from last PP stage""" + pp = get_pp_group() + assert not pp.is_last_rank + num_reqs = self.input_batch.num_reqs + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. + recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device) + torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group) + self.input_batch.prev_sampled_token_ids = recv + + # construct `prev_req_id_to_index` here so `_prepare_input_ids` + # can map req_id -> previous batch row + discard_req_indices = np.nonzero(self.discard_request_mask.np[:num_reqs])[0] + discard_req_indices_set = set(discard_req_indices) + prev_req_id_to_index: dict[str, int] = {} + for i, req_id in enumerate(self.input_batch.req_ids): + if i in discard_req_indices_set: + continue + prev_req_id_to_index[req_id] = i + # PP+async scheduling: advance per-request local cached output length by + # appending a placeholder (-1) token id. + if (req_state := self.requests.get(req_id)) is not None: + req_state.output_token_ids.append(-1) + self.input_batch.prev_req_id_to_index = prev_req_id_to_index + def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens or not self._draft_token_req_ids: return None