[Model Runner V2] Minor cleanup for PP (#34666)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-16 19:15:31 -08:00
committed by GitHub
parent d74278fb67
commit 04925b2202
2 changed files with 53 additions and 58 deletions

View File

@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
from vllm.v1.worker.gpu.pp_handler import PPHandler
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler
@@ -184,9 +184,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.pp_handler: PPHandler | None = (
get_pp_handler(self.parallel_config) if self.use_pp else None
)
if self.use_pp:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
self.pp_handler: PPHandler | None = PPHandler(self.device)
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
self.pp_handler = None
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
@@ -318,7 +323,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors = None
if self.use_pp and not get_pp_group().is_first_rank:
if not self.is_first_pp_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
@@ -335,7 +340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_connector.set_disabled(False)
# Non-last PP ranks don't produce output for sampling.
if self.use_pp and not get_pp_group().is_last_rank:
if not self.is_last_pp_rank:
return None, None
assert self.execute_model_state is not None
@@ -373,20 +378,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True
)
# Only run sampler on last PP rank (non-last ranks return None).
if not self.use_pp or get_pp_group().is_last_rank:
if self.is_last_pp_rank:
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
if self.do_spec_decode:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -890,9 +898,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self._set_active_loras(*lora_inputs)
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and (
not self.use_pp or get_pp_group().is_first_rank
):
if self.supports_mm_inputs and self.is_first_pp_rank:
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
@@ -935,6 +941,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
if self.is_first_pp_rank:
input_ids = input_batch.input_ids
inputs_embeds = input_batch.inputs_embeds
assert intermediate_tensors is None
else:
input_ids = None
inputs_embeds = None
assert intermediate_tensors is not None
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
@@ -945,25 +960,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings,
):
self.kv_connector.pre_forward(scheduler_output)
if self.use_pp and not get_pp_group().is_first_rank:
# Non-first PP rank: forward with intermediate tensors.
assert intermediate_tensors is not None
hidden_states = self.model(
input_ids=None,
positions=positions,
inputs_embeds=None,
intermediate_tensors=intermediate_tensors,
)
else:
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=positions,
inputs_embeds=input_batch.inputs_embeds,
)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
if self.use_pp and not get_pp_group().is_last_rank:
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
@@ -986,16 +992,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if self.use_pp and not get_pp_group().is_last_rank:
if not self.is_last_pp_rank:
assert self.pp_handler is not None
received = self.pp_handler.maybe_receive_sampled_tokens(
input_batch.num_reqs,
self.device,
max_sample_len=self.num_speculative_steps + 1,
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
)
if received is not None:
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
assert received is not None
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
# Last rank: sample tokens

View File

@@ -15,6 +15,9 @@ class PPHandler:
Only instantiated when PP is enabled (pp_size > 1).
"""
def __init__(self, device: torch.device) -> None:
    """Remember the device used for this handler's tensor allocations.

    Args:
        device: Device on which ``maybe_receive_sampled_tokens`` allocates
            the buffers it fills via ``torch.distributed.broadcast``.
    """
    self.device = device
def maybe_broadcast_sampled_tokens(
self,
sampler_output: SamplerOutput,
@@ -59,7 +62,6 @@ class PPHandler:
def maybe_receive_sampled_tokens(
self,
num_reqs: int,
device: torch.device,
max_sample_len: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
"""Receive sampled tokens broadcast by the last PP rank.
@@ -84,7 +86,7 @@ class PPHandler:
return None
sampled_tokens = torch.empty(
num_reqs, max_sample_len, dtype=torch.int64, device=device
num_reqs, max_sample_len, dtype=torch.int64, device=self.device
)
torch.distributed.broadcast(
sampled_tokens,
@@ -93,27 +95,16 @@ class PPHandler:
)
# NOTE: num_sampled/num_rejected are only needed
# for speculative decoding.
num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device)
num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=self.device)
torch.distributed.broadcast(
num_sampled,
src=pp.last_rank,
group=pp.device_group,
)
num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device)
num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=self.device)
torch.distributed.broadcast(
num_rejected,
src=pp.last_rank,
group=pp.device_group,
)
return sampled_tokens, num_sampled, num_rejected
def get_pp_handler(parallel_config) -> PPHandler:
"""Factory function to create PPHandler.
Must only be called when PP is enabled (pp_size > 1).
"""
assert parallel_config.pipeline_parallel_size > 1, (
"PPHandler should not be created when pipeline parallelism is disabled."
)
return PPHandler()