[Model Runner V2] A bit more PP simplification (#34766)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-17 21:39:07 -08:00
committed by GitHub
parent 30ebe0dc3c
commit a49ea5a58f
2 changed files with 8 additions and 12 deletions

View File

@@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, kv_connector_output = self.execute_model_state
self.execute_model_state = None # type: ignore
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if not self.is_last_pp_rank:
received = pp_receive(
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast from the last rank and update local state.
sampled, num_sampled, num_rejected = pp_receive(
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
)
assert received is not None
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
@@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, input_batch, grammar_output
)
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if self.use_pp:
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(

View File

@@ -13,8 +13,7 @@ def pp_broadcast(
num_rejected: torch.Tensor,
) -> None:
pp = get_pp_group()
if not pp.is_last_rank:
return
assert pp.is_last_rank
assert sampled_token_ids.dtype == torch.int64
torch.distributed.broadcast(
@@ -27,10 +26,9 @@ def pp_broadcast(
def pp_receive(
num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pp = get_pp_group()
if pp.is_last_rank:
return None
assert not pp.is_last_rank
sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=pp.device