diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 0ca0e828b..8cca3cb46 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
 from vllm.v1.worker.gpu.lora_utils import LoraState
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
-from vllm.v1.worker.gpu.pp_handler import PPHandler
+from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
@@ -185,11 +185,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.use_pp:
             self.is_first_pp_rank = get_pp_group().is_first_rank
             self.is_last_pp_rank = get_pp_group().is_last_rank
-            self.pp_handler: PPHandler | None = PPHandler(self.device)
         else:
             self.is_first_pp_rank = True
             self.is_last_pp_rank = True
-            self.pp_handler = None
 
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
@@ -987,8 +985,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # IntermediateTensors instead of final hidden states. Receive the
         # sampled tokens broadcast by the last rank and update local state.
         if not self.is_last_pp_rank:
-            assert self.pp_handler is not None
-            received = self.pp_handler.maybe_receive_sampled_tokens(
+            received = pp_receive(
                 input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
             )
             assert received is not None
@@ -1003,10 +1000,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Broadcast to non-last PP ranks (handles spec decode multi-token).
         if self.use_pp:
-            assert self.pp_handler is not None
-            self.pp_handler.maybe_broadcast_sampled_tokens(
-                sampler_output, num_sampled, num_rejected
-            )
+            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
 
         prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
             self.model.compute_logits,
diff --git a/vllm/v1/worker/gpu/pp_handler.py b/vllm/v1/worker/gpu/pp_handler.py
deleted file mode 100644
index e98ffd89b..000000000
--- a/vllm/v1/worker/gpu/pp_handler.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pipeline Parallelism handler for V2 Model Runner."""
-
-import torch
-
-from vllm.distributed.parallel_state import get_pp_group
-from vllm.v1.worker.gpu.sample.output import SamplerOutput
-
-
-class PPHandler:
-    """Pipeline parallelism handler for Model Runner V2.
-
-    Manages sampled token synchronization between PP ranks.
-    Only instantiated when PP is enabled (pp_size > 1).
-    """
-
-    def __init__(self, device: torch.device):
-        self.device = device
-
-    def maybe_broadcast_sampled_tokens(
-        self,
-        sampler_output: SamplerOutput,
-        num_sampled: torch.Tensor,
-        num_rejected: torch.Tensor,
-    ) -> None:
-        """Broadcast sampled tokens from the last PP rank to all other ranks.
-
-        No-ops if this is not the last rank.
-
-        Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled
-        [num_reqs], and num_rejected [num_reqs] to support both regular decode
-        and speculative decoding.
-
-        Args:
-            sampler_output: SamplerOutput from sampling.
-            num_sampled: Number of accepted tokens per request.
-            num_rejected: Number of rejected tokens per request.
- """ - pp = get_pp_group() - if not pp.is_last_rank: - return - - torch.distributed.broadcast( - sampler_output.sampled_token_ids.contiguous(), - src=pp.last_rank, - group=pp.device_group, - ) - # NOTE: num_sampled/num_rejected are only needed - # for speculative decoding. - torch.distributed.broadcast( - num_sampled.contiguous(), - src=pp.last_rank, - group=pp.device_group, - ) - torch.distributed.broadcast( - num_rejected.contiguous(), - src=pp.last_rank, - group=pp.device_group, - ) - - def maybe_receive_sampled_tokens( - self, - num_reqs: int, - max_sample_len: int = 1, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: - """Receive sampled tokens broadcast by the last PP rank. - - Returns None if this is the last rank (which samples, not receives). - - Args: - num_reqs: Number of requests in the batch. - max_sample_len: Maximum number of tokens sampled per request - (1 for regular decode, >1 for speculative decoding). - - Returns: - None if called on last rank. - Otherwise, tuple of (sampled_tokens, num_sampled, num_rejected): - - sampled_tokens: shape [num_reqs, max_sample_len] - - num_sampled: shape [num_reqs] - - num_rejected: shape [num_reqs] - """ - pp = get_pp_group() - if pp.is_last_rank: - return None - - sampled_tokens = torch.empty( - num_reqs, max_sample_len, dtype=torch.int64, device=self.device - ) - torch.distributed.broadcast( - sampled_tokens, - src=pp.last_rank, - group=pp.device_group, - ) - # NOTE: num_sampled/num_rejected are only needed - # for speculative decoding. - num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=self.device) - torch.distributed.broadcast( - num_sampled, - src=pp.last_rank, - group=pp.device_group, - ) - num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=self.device) - torch.distributed.broadcast( - num_rejected, - src=pp.last_rank, - group=pp.device_group, - ) - return sampled_tokens, num_sampled, num_rejected diff --git a/vllm/v1/worker/gpu/pp_utils.py b/vllm/v1/worker/gpu/pp_utils.py new file mode 100644 index 000000000..8cf868b2f --- /dev/null +++ b/vllm/v1/worker/gpu/pp_utils.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pipeline Parallelism utils for V2 Model Runner.""" + +import torch + +from vllm.distributed.parallel_state import get_pp_group + + +def pp_broadcast( + sampled_token_ids: torch.Tensor, + num_sampled: torch.Tensor, + num_rejected: torch.Tensor, +) -> None: + pp = get_pp_group() + if not pp.is_last_rank: + return + + assert sampled_token_ids.dtype == torch.int64 + torch.distributed.broadcast( + sampled_token_ids.contiguous(), src=pp.last_rank, group=pp.device_group + ) + + combined = torch.stack((num_sampled, num_rejected), dim=0) + torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group) + + +def pp_receive( + num_reqs: int, max_sample_len: int = 1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: + pp = get_pp_group() + if pp.is_last_rank: + return None + + sampled_tokens = torch.empty( + num_reqs, max_sample_len, dtype=torch.int64, device=pp.device + ) + torch.distributed.broadcast(sampled_tokens, src=pp.last_rank, group=pp.device_group) + + combined = torch.empty(2, num_reqs, dtype=torch.int32, device=pp.device) + torch.distributed.broadcast(combined, src=pp.last_rank, group=pp.device_group) + num_sampled, num_rejected = combined.unbind(dim=0) + return sampled_tokens, num_sampled, num_rejected