[Model Runner V2][Minor] Simplify PP logic (#38031)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-24 13:57:17 -07:00
committed by GitHub
parent 0c1809c806
commit 4e824d1c83
3 changed files with 23 additions and 28 deletions

View File

@@ -62,3 +62,12 @@ class IntermediateTensors:
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
@staticmethod
def empty_like(
intermediate_tensors: "IntermediateTensors",
) -> "IntermediateTensors":
tensors = {
k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items()
}
return IntermediateTensors(tensors)

View File

@@ -94,13 +94,8 @@ class CudaGraphManager:
self.decode_query_len = decode_query_len
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size self.is_first_pp_rank = get_pp_group().is_first_rank
if self.pp_size > 1: self.is_last_pp_rank = get_pp_group().is_last_rank
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
@@ -371,14 +366,11 @@ class ModelCudaGraphManager(CudaGraphManager):
self.aux_hidden_states[i][:num_tokens] = aux
else:
# Non-last PP rank.
assert isinstance(model_output, IntermediateTensors)
intermediate_tensors = model_output
assert isinstance(intermediate_tensors, IntermediateTensors)
if self.intermediate_tensors is None:
self.intermediate_tensors = IntermediateTensors( self.intermediate_tensors = IntermediateTensors.empty_like(
{ intermediate_tensors
k: torch.empty_like(v)
for k, v in intermediate_tensors.tensors.items()
}
)
for k, v in intermediate_tensors.tensors.items():
self.intermediate_tensors[k][:num_tokens] = v

View File

@@ -132,14 +132,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.output_copy_event = torch.cuda.Event()
# Pipeline parallelism.
self.pp_size = self.parallel_config.pipeline_parallel_size self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_pp = self.pp_size > 1 self.is_first_pp_rank = get_pp_group().is_first_rank
if self.use_pp: self.is_last_pp_rank = get_pp_group().is_last_rank
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None
@@ -179,7 +175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config.method == "eagle3":
# EAGLE3 may require auxiliary hidden states from target model outputs.
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1: if self.use_pp:
raise ValueError("EAGLE3 with pipeline parallel is not supported.") raise ValueError("EAGLE3 with pipeline parallel is not supported.")
# Draft tokens propagation - for spec-dec + struct outputs.
@@ -270,8 +266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config, vllm_config=self.vllm_config, model_config=self.vllm_config.model_config
model_config=self.vllm_config.model_config,
)
if self.lora_config:
self.model = self.load_lora_model(
@@ -1026,14 +1021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
intermediate_tensors = IntermediateTensors( model_inputs["intermediate_tensors"] = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
}, }
intermediate_tensors.kv_connector_output,
)
model_inputs["intermediate_tensors"] = intermediate_tensors del intermediate_tensors
# Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL: