diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 2c50ea15f..e8f7e051b 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -1003,15 +1003,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, input_batch, kv_connector_output = self.execute_model_state self.execute_model_state = None # type: ignore - # Non-last PP rank: hidden_states is None because this rank produced - # IntermediateTensors instead of final hidden states. Receive the - # sampled tokens broadcast by the last rank and update local state. if not self.is_last_pp_rank: - received = pp_receive( + # Non-last PP rank: hidden_states is None because this rank produced + # IntermediateTensors instead of final hidden states. Receive the + # sampled tokens broadcast from the last rank and update local state. + sampled, num_sampled, num_rejected = pp_receive( input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 ) - assert received is not None - sampled, num_sampled, num_rejected = received self.postprocess(input_batch, sampled, num_sampled, num_rejected) return None @@ -1020,8 +1018,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, input_batch, grammar_output ) - # Broadcast to non-last PP ranks (handles spec decode multi-token). if self.use_pp: + # Broadcast to non-last PP ranks (handles spec decode multi-token). pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected) prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs( diff --git a/vllm/v1/worker/gpu/pp_utils.py b/vllm/v1/worker/gpu/pp_utils.py index 8cf868b2f..bf379b5fb 100644 --- a/vllm/v1/worker/gpu/pp_utils.py +++ b/vllm/v1/worker/gpu/pp_utils.py @@ -13,8 +13,7 @@ def pp_broadcast( num_rejected: torch.Tensor, ) -> None: pp = get_pp_group() - if not pp.is_last_rank: - return + assert pp.is_last_rank assert sampled_token_ids.dtype == torch.int64 torch.distributed.broadcast( @@ -27,10 +26,9 @@ def pp_broadcast( def pp_receive( num_reqs: int, max_sample_len: int = 1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pp = get_pp_group() - if pp.is_last_rank: - return None + assert not pp.is_last_rank sampled_tokens = torch.empty( num_reqs, max_sample_len, dtype=torch.int64, device=pp.device