[Model Runner V2] Minor cleanup for PP (#34666)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
|
||||
from vllm.v1.worker.gpu.lora_utils import LoraState
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
|
||||
from vllm.v1.worker.gpu.pp_handler import PPHandler
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
@@ -184,9 +184,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Pipeline parallelism.
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.pp_handler: PPHandler | None = (
|
||||
get_pp_handler(self.parallel_config) if self.use_pp else None
|
||||
)
|
||||
if self.use_pp:
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
self.pp_handler: PPHandler | None = PPHandler(self.device)
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
self.pp_handler = None
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
@@ -318,7 +323,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# For non-first PP ranks, create dummy intermediate_tensors.
|
||||
intermediate_tensors = None
|
||||
if self.use_pp and not get_pp_group().is_first_rank:
|
||||
if not self.is_first_pp_rank:
|
||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
@@ -335,7 +340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_connector.set_disabled(False)
|
||||
|
||||
# Non-last PP ranks don't produce output for sampling.
|
||||
if self.use_pp and not get_pp_group().is_last_rank:
|
||||
if not self.is_last_pp_rank:
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
@@ -373,20 +378,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens, skip_attn=True
|
||||
)
|
||||
|
||||
# Only run sampler on last PP rank (non-last ranks return None).
|
||||
if not self.use_pp or get_pp_group().is_last_rank:
|
||||
if self.is_last_pp_rank:
|
||||
assert sample_hidden_states is not None
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
if self.do_spec_decode:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
slot_mappings=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
if self.do_spec_decode:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
slot_mappings=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
@@ -890,9 +898,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self._set_active_loras(*lora_inputs)
|
||||
|
||||
# Only first PP rank prepares multimodal embeddings.
|
||||
if self.supports_mm_inputs and (
|
||||
not self.use_pp or get_pp_group().is_first_rank
|
||||
):
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
mm_embeds, is_mm_embed = self.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs, input_batch
|
||||
)
|
||||
@@ -935,6 +941,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert input_batch.mrope_positions is not None
|
||||
positions = input_batch.mrope_positions
|
||||
|
||||
if self.is_first_pp_rank:
|
||||
input_ids = input_batch.input_ids
|
||||
inputs_embeds = input_batch.inputs_embeds
|
||||
assert intermediate_tensors is None
|
||||
else:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -945,25 +960,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
slot_mapping=input_batch.slot_mappings,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
if self.use_pp and not get_pp_group().is_first_rank:
|
||||
# Non-first PP rank: forward with intermediate tensors.
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=None,
|
||||
positions=positions,
|
||||
inputs_embeds=None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=input_batch.inputs_embeds,
|
||||
)
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
|
||||
if self.use_pp and not get_pp_group().is_last_rank:
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: return IntermediateTensors for sending.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
@@ -986,16 +992,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Non-last PP rank: hidden_states is None because this rank produced
|
||||
# IntermediateTensors instead of final hidden states. Receive the
|
||||
# sampled tokens broadcast by the last rank and update local state.
|
||||
if self.use_pp and not get_pp_group().is_last_rank:
|
||||
if not self.is_last_pp_rank:
|
||||
assert self.pp_handler is not None
|
||||
received = self.pp_handler.maybe_receive_sampled_tokens(
|
||||
input_batch.num_reqs,
|
||||
self.device,
|
||||
max_sample_len=self.num_speculative_steps + 1,
|
||||
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
|
||||
)
|
||||
if received is not None:
|
||||
sampled, num_sampled, num_rejected = received
|
||||
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
|
||||
assert received is not None
|
||||
sampled, num_sampled, num_rejected = received
|
||||
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
|
||||
return None
|
||||
|
||||
# Last rank: sample tokens
|
||||
|
||||
@@ -15,6 +15,9 @@ class PPHandler:
|
||||
Only instantiated when PP is enabled (pp_size > 1).
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
self.device = device
|
||||
|
||||
def maybe_broadcast_sampled_tokens(
|
||||
self,
|
||||
sampler_output: SamplerOutput,
|
||||
@@ -59,7 +62,6 @@ class PPHandler:
|
||||
def maybe_receive_sampled_tokens(
|
||||
self,
|
||||
num_reqs: int,
|
||||
device: torch.device,
|
||||
max_sample_len: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
|
||||
"""Receive sampled tokens broadcast by the last PP rank.
|
||||
@@ -84,7 +86,7 @@ class PPHandler:
|
||||
return None
|
||||
|
||||
sampled_tokens = torch.empty(
|
||||
num_reqs, max_sample_len, dtype=torch.int64, device=device
|
||||
num_reqs, max_sample_len, dtype=torch.int64, device=self.device
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
sampled_tokens,
|
||||
@@ -93,27 +95,16 @@ class PPHandler:
|
||||
)
|
||||
# NOTE: num_sampled/num_rejected are only needed
|
||||
# for speculative decoding.
|
||||
num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=device)
|
||||
num_sampled = torch.empty(num_reqs, dtype=torch.int32, device=self.device)
|
||||
torch.distributed.broadcast(
|
||||
num_sampled,
|
||||
src=pp.last_rank,
|
||||
group=pp.device_group,
|
||||
)
|
||||
num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=device)
|
||||
num_rejected = torch.empty(num_reqs, dtype=torch.int32, device=self.device)
|
||||
torch.distributed.broadcast(
|
||||
num_rejected,
|
||||
src=pp.last_rank,
|
||||
group=pp.device_group,
|
||||
)
|
||||
return sampled_tokens, num_sampled, num_rejected
|
||||
|
||||
|
||||
def get_pp_handler(parallel_config) -> PPHandler:
|
||||
"""Factory function to create PPHandler.
|
||||
|
||||
Must only be called when PP is enabled (pp_size > 1).
|
||||
"""
|
||||
assert parallel_config.pipeline_parallel_size > 1, (
|
||||
"PPHandler should not be created when pipeline parallelism is disabled."
|
||||
)
|
||||
return PPHandler()
|
||||
|
||||
Reference in New Issue
Block a user