[Feature] Fully support for async scheduling + PP, 30.8% E2E throughput improvement, 31.8% TPOT improvement (#32618)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -42,6 +42,22 @@ def test_compile_config_repr_succeeds():
|
||||
assert "inductor_passes" in val
|
||||
|
||||
|
||||
def test_async_scheduling_with_pipeline_parallelism_is_allowed():
|
||||
cfg = VllmConfig(
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_model_len=8192,
|
||||
is_encoder_decoder=False,
|
||||
async_scheduling=True,
|
||||
),
|
||||
parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=2,
|
||||
distributed_executor_backend="mp",
|
||||
nnodes=2,
|
||||
),
|
||||
)
|
||||
assert cfg.scheduler_config.async_scheduling is True
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TestConfigFields:
|
||||
a: int
|
||||
|
||||
@@ -127,6 +127,21 @@ def test_schedule_multimodal_requests():
|
||||
assert len(encoder_input) == 1
|
||||
|
||||
|
||||
def test_async_scheduling_pp_allows_rescheduling_with_output_placeholders():
|
||||
"""Async scheduling + PP: allow multi-step in-flight scheduling per request"""
|
||||
scheduler = create_scheduler(async_scheduling=True, pipeline_parallel_size=2)
|
||||
(req,) = create_requests(num_requests=1, num_tokens=8)
|
||||
scheduler.add_request(req)
|
||||
|
||||
_ = scheduler.schedule()
|
||||
assert req.num_output_placeholders > 0
|
||||
|
||||
# before any update_from_output, we still expect the request can be
|
||||
# scheduled again (multi-step in-flight).
|
||||
output = scheduler.schedule()
|
||||
assert req.request_id in output.num_scheduled_tokens
|
||||
|
||||
|
||||
def test_schedule_partial_requests():
|
||||
"""Test scheduling behavior with partial requests.
|
||||
|
||||
|
||||
@@ -344,13 +344,6 @@ class Scheduler(SchedulerInterface):
|
||||
while req_index < len(self.running) and token_budget > 0:
|
||||
request = self.running[req_index]
|
||||
|
||||
# do not schedule another step for the same request while it still has
|
||||
# output placeholders for PP.
|
||||
# TODO: support PP + async scheduling without this limit
|
||||
if self.use_pp and request.num_output_placeholders > 0:
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
if (
|
||||
request.num_output_placeholders > 0
|
||||
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
|
||||
@@ -1003,7 +996,10 @@ class Scheduler(SchedulerInterface):
|
||||
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
|
||||
req_id = req.request_id
|
||||
req_ids.append(req_id)
|
||||
if self.use_pp:
|
||||
# NOTE: In PP+async scheduling, we consume token ids via a direct GPU
|
||||
# broadcast path (`input_batch.prev_sampled_token_ids`), so we can
|
||||
# omit this payload.
|
||||
if self.use_pp and not self.scheduler_config.async_scheduling:
|
||||
# When using PP, the scheduler sends the sampled tokens back,
|
||||
# because there's no direct communication between the first-
|
||||
# stage worker and the last-stage worker. Otherwise, we don't
|
||||
|
||||
@@ -1010,20 +1010,26 @@ class GPUModelRunner(
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
|
||||
if not is_last_rank:
|
||||
# When using PP, the scheduler sends the sampled tokens back,
|
||||
# because there's no direct communication between the first-
|
||||
# stage worker and the last-stage worker.
|
||||
new_token_ids = req_data.new_token_ids[i]
|
||||
# Add the sampled token(s) from the previous step (if any).
|
||||
# This doesn't include "unverified" tokens like spec tokens.
|
||||
num_new_tokens = (
|
||||
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
|
||||
)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
|
||||
if not req_data.new_token_ids:
|
||||
# Async scheduled PP: Sampled tokens propagated via GPU broadcast.
|
||||
new_token_ids: list[int] = []
|
||||
else:
|
||||
# Non-async scheduling with PP: The scheduler sends
|
||||
# sampled token ids back because there's no direct communication
|
||||
# between the first-stage worker and the last-stage worker.
|
||||
new_token_ids = req_data.new_token_ids[i]
|
||||
# Add the sampled token(s) from the previous step (if any).
|
||||
# This doesn't include "unverified" tokens like spec tokens.
|
||||
num_new_tokens = (
|
||||
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
|
||||
)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(
|
||||
new_token_ids[-num_new_tokens:]
|
||||
)
|
||||
elif num_output_tokens < len(req_state.output_token_ids):
|
||||
# Some output tokens were discarded due to a sync-KV-load
|
||||
# failure. Align the cached state.
|
||||
@@ -3577,7 +3583,9 @@ class GPUModelRunner(
|
||||
self.kv_connector_output = None
|
||||
|
||||
if self.execute_model_state is None:
|
||||
# Nothing to do (PP non-final rank case), output isn't used.
|
||||
# receive sampled token ids from the last PP rank.
|
||||
if self.use_async_scheduling and get_pp_group().world_size > 1:
|
||||
self._pp_receive_prev_sampled_token_ids_to_input_batch()
|
||||
if not kv_connector_output:
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
@@ -3618,6 +3626,12 @@ class GPUModelRunner(
|
||||
self._update_states_after_model_execute(
|
||||
sampler_output.sampled_token_ids, scheduler_output
|
||||
)
|
||||
if self.use_async_scheduling:
|
||||
pp = get_pp_group()
|
||||
if pp.world_size > 1 and pp.is_last_rank:
|
||||
self._pp_broadcast_prev_sampled_token_ids(
|
||||
sampler_output.sampled_token_ids
|
||||
)
|
||||
|
||||
self._draft_token_ids = None
|
||||
self._draft_token_req_ids = None
|
||||
@@ -3753,6 +3767,45 @@ class GPUModelRunner(
|
||||
|
||||
return async_output
|
||||
|
||||
def _pp_broadcast_prev_sampled_token_ids(
|
||||
self, sampled_token_ids: torch.Tensor
|
||||
) -> None:
|
||||
"""Broadcast sampled token ids (GPU) from last PP stage"""
|
||||
pp = get_pp_group()
|
||||
assert pp.is_last_rank
|
||||
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
|
||||
assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, (
|
||||
"PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
sampled_token_ids, src=pp.rank, group=pp.device_group
|
||||
)
|
||||
|
||||
def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None:
|
||||
"""Receive sampled token ids broadcast from last PP stage"""
|
||||
pp = get_pp_group()
|
||||
assert not pp.is_last_rank
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
|
||||
recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device)
|
||||
torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group)
|
||||
self.input_batch.prev_sampled_token_ids = recv
|
||||
|
||||
# construct `prev_req_id_to_index` here so `_prepare_input_ids`
|
||||
# can map req_id -> previous batch row
|
||||
discard_req_indices = np.nonzero(self.discard_request_mask.np[:num_reqs])[0]
|
||||
discard_req_indices_set = set(discard_req_indices)
|
||||
prev_req_id_to_index: dict[str, int] = {}
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
if i in discard_req_indices_set:
|
||||
continue
|
||||
prev_req_id_to_index[req_id] = i
|
||||
# PP+async scheduling: advance per-request local cached output length by
|
||||
# appending a placeholder (-1) token id.
|
||||
if (req_state := self.requests.get(req_id)) is not None:
|
||||
req_state.output_token_ids.append(-1)
|
||||
self.input_batch.prev_req_id_to_index = prev_req_id_to_index
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
if not self.num_spec_tokens or not self._draft_token_req_ids:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user