[Model Runner V2] A bit more PP simplification (#34766)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states, input_batch, kv_connector_output = self.execute_model_state
         self.execute_model_state = None # type: ignore
 
-        # Non-last PP rank: hidden_states is None because this rank produced
-        # IntermediateTensors instead of final hidden states. Receive the
-        # sampled tokens broadcast by the last rank and update local state.
         if not self.is_last_pp_rank:
-            received = pp_receive(
+            # Non-last PP rank: hidden_states is None because this rank produced
+            # IntermediateTensors instead of final hidden states. Receive the
+            # sampled tokens broadcast from the last rank and update local state.
+            sampled, num_sampled, num_rejected = pp_receive(
                 input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
             )
-            assert received is not None
-            sampled, num_sampled, num_rejected = received
             self.postprocess(input_batch, sampled, num_sampled, num_rejected)
             return None
 
@@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states, input_batch, grammar_output
         )
 
-        # Broadcast to non-last PP ranks (handles spec decode multi-token).
         if self.use_pp:
+            # Broadcast to non-last PP ranks (handles spec decode multi-token).
             pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
 
         prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
|
||||
|
||||
@@ -13,8 +13,7 @@ def pp_broadcast(
     num_rejected: torch.Tensor,
 ) -> None:
     pp = get_pp_group()
-    if not pp.is_last_rank:
-        return
+    assert pp.is_last_rank
 
     assert sampled_token_ids.dtype == torch.int64
     torch.distributed.broadcast(
@@ -27,10 +26,9 @@ def pp_broadcast(
 
 def pp_receive(
     num_reqs: int, max_sample_len: int = 1
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     pp = get_pp_group()
-    if pp.is_last_rank:
-        return None
+    assert not pp.is_last_rank
 
     sampled_tokens = torch.empty(
         num_reqs, max_sample_len, dtype=torch.int64, device=pp.device
|
||||
|
||||
Reference in New Issue
Block a user