diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 46c43727c..273cecd3b 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import ( from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState -from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler +from vllm.v1.worker.gpu.pp_handler import PPHandler from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker from vllm.v1.worker.gpu.sample.sampler import Sampler @@ -184,9 +184,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Pipeline parallelism. self.use_pp = self.parallel_config.pipeline_parallel_size > 1 - self.pp_handler: PPHandler | None = ( - get_pp_handler(self.parallel_config) if self.use_pp else None - ) + if self.use_pp: + self.is_first_pp_rank = get_pp_group().is_first_rank + self.is_last_pp_rank = get_pp_group().is_last_rank + self.pp_handler: PPHandler | None = PPHandler(self.device) + else: + self.is_first_pp_rank = True + self.is_last_pp_rank = True + self.pp_handler = None def update_max_model_len(self, max_model_len: int) -> None: self.max_model_len = max_model_len @@ -318,7 +323,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # For non-first PP ranks, create dummy intermediate_tensors. intermediate_tensors = None - if self.use_pp and not get_pp_group().is_first_rank: + if not self.is_first_pp_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( batch_size=num_tokens, dtype=self.model_config.dtype, @@ -335,7 +340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.kv_connector.set_disabled(False) # Non-last PP ranks don't produce output for sampling. - if self.use_pp and not get_pp_group().is_last_rank: + if not self.is_last_pp_rank: return None, None assert self.execute_model_state is not None @@ -373,20 +378,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): hidden_states, sample_hidden_states = self._dummy_run( self.max_num_tokens, skip_attn=True ) + # Only run sampler on last PP rank (non-last ranks return None). - if not self.use_pp or get_pp_group().is_last_rank: + if self.is_last_pp_rank: assert sample_hidden_states is not None self._dummy_sampler_run(sample_hidden_states) - if self.do_spec_decode: - num_tokens_across_dp = make_num_tokens_across_dp( - self.parallel_config.data_parallel_size, self.max_num_tokens - ) - self.speculator.run_model( - self.max_num_tokens, - attn_metadata=None, - slot_mappings=None, - num_tokens_across_dp=num_tokens_across_dp, - ) + + if self.do_spec_decode: + num_tokens_across_dp = make_num_tokens_across_dp( + self.parallel_config.data_parallel_size, self.max_num_tokens + ) + self.speculator.run_model( + self.max_num_tokens, + attn_metadata=None, + slot_mappings=None, + num_tokens_across_dp=num_tokens_across_dp, + ) + torch.cuda.synchronize() del hidden_states, sample_hidden_states gc.collect() @@ -890,9 +898,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self._set_active_loras(*lora_inputs) # Only first PP rank prepares multimodal embeddings. - if self.supports_mm_inputs and ( - not self.use_pp or get_pp_group().is_first_rank - ): + if self.supports_mm_inputs and self.is_first_pp_rank: mm_embeds, is_mm_embed = self.get_mm_embeddings( scheduler_output.scheduled_encoder_inputs, input_batch ) @@ -935,6 +941,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert input_batch.mrope_positions is not None positions = input_batch.mrope_positions + if self.is_first_pp_rank: + input_ids = input_batch.input_ids + inputs_embeds = input_batch.inputs_embeds + assert intermediate_tensors is None + else: + input_ids = None + inputs_embeds = None + assert intermediate_tensors is not None + with set_forward_context( input_batch.attn_metadata, self.vllm_config, @@ -945,25 +960,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): slot_mapping=input_batch.slot_mappings, ): self.kv_connector.pre_forward(scheduler_output) - if self.use_pp and not get_pp_group().is_first_rank: - # Non-first PP rank: forward with intermediate tensors. - assert intermediate_tensors is not None - hidden_states = self.model( - input_ids=None, - positions=positions, - inputs_embeds=None, - intermediate_tensors=intermediate_tensors, - ) - else: - hidden_states = self.model( - input_ids=input_batch.input_ids, - positions=positions, - inputs_embeds=input_batch.inputs_embeds, - ) + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) kv_connector_output = self.kv_connector.post_forward(scheduler_output) - if self.use_pp and not get_pp_group().is_last_rank: + if not self.is_last_pp_rank: # Non-last PP rank: return IntermediateTensors for sending. assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output @@ -986,16 +992,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Non-last PP rank: hidden_states is None because this rank produced # IntermediateTensors instead of final hidden states. Receive the # sampled tokens broadcast by the last rank and update local state. - if self.use_pp and not get_pp_group().is_last_rank: + if not self.is_last_pp_rank: assert self.pp_handler is not None received = self.pp_handler.maybe_receive_sampled_tokens( - input_batch.num_reqs, - self.device, - max_sample_len=self.num_speculative_steps + 1, + input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1 ) - if received is not None: - sampled, num_sampled, num_rejected = received - self.postprocess(input_batch, sampled, num_sampled, num_rejected) + assert received is not None + sampled, num_sampled, num_rejected = received + self.postprocess(input_batch, sampled, num_sampled, num_rejected) return None # Last rank: sample tokens diff --git a/vllm/v1/worker/gpu/pp_handler.py b/vllm/v1/worker/gpu/pp_handler.py index a254f577f..b4faec348 100644 --- a/vllm/v1/worker/gpu/pp_handler.py +++ b/vllm/v1/worker/gpu/pp_handler.py @@ -15,6 +15,9 @@ class PPHandler: Only instantiated when PP is enabled (pp_size > 1). """ + def __init__(self, device: torch.device): + self.device = device + def maybe_broadcast_sampled_tokens( self, sampler_output: SamplerOutput, @@ -59,7 +62,6 @@ class PPHandler: def maybe_receive_sampled_tokens( self, num_reqs: int, - device: torch.device, max_sample_len: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: """Receive sampled tokens broadcast by the last PP rank. @@ -84,7 +86,7 @@ class PPHandler: return None sampled_tokens = torch.empty( - num_reqs, max_sample_len, dtype=torch.int64, device=device + num_reqs, max_sample_len, dtype=torch.int64, device=self.device ) torch.distributed.broadcast( sampled_tokens, @@ -93,27 +95,16 @@ class PPHandler: ) # NOTE: num_sampled/num_rejected are only needed # for speculative decoding. - num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device) + num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=self.device) torch.distributed.broadcast( num_sampled, src=pp.last_rank, group=pp.device_group, ) - num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device) + num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=self.device) torch.distributed.broadcast( num_rejected, src=pp.last_rank, group=pp.device_group, ) return sampled_tokens, num_sampled, num_rejected - - -def get_pp_handler(parallel_config) -> PPHandler: - """Factory function to create PPHandler. - - Must only be called when PP is enabled (pp_size > 1). - """ - assert parallel_config.pipeline_parallel_size > 1, ( - "PPHandler should not be created when pipeline parallelism is disabled." - ) - return PPHandler()