[Model Runner V2] Enable piecewise CUDA graphs for pipeline parallelism (#35162)

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
zhanqiuhu
2026-03-22 16:48:25 -04:00
committed by GitHub
parent a5e9d511de
commit 63f49b8bd4
2 changed files with 102 additions and 46 deletions

View File

@@ -11,11 +11,16 @@ from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.distributed.parallel_state import (
get_pp_group,
graph_capture,
is_global_first_rank,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
@@ -87,7 +92,15 @@ class CudaGraphManager:
assert self.compilation_config is not None
self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
if self.pp_size > 1:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
@@ -267,12 +280,14 @@ class ModelCudaGraphManager(CudaGraphManager):
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False
self.intermediate_tensors: IntermediateTensors | None = None
def capture(
self,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
intermediate_tensors: IntermediateTensors | None,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
@@ -293,6 +308,19 @@ class ModelCudaGraphManager(CudaGraphManager):
if self.dp_size > 1
else None
)
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
assert intermediate_tensors is not None
model_inputs["intermediate_tensors"] = intermediate_tensors[:num_tokens]
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
@@ -318,21 +346,15 @@ class ModelCudaGraphManager(CudaGraphManager):
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors": None,
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
model_output = model(**model_inputs)
if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None
if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None
if self.is_last_pp_rank:
# Last PP rank (common case).
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -340,13 +362,26 @@ class ModelCudaGraphManager(CudaGraphManager):
aux_hidden_states = []
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
]
self.hidden_states[:num_tokens] = hidden_states
for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux
else:
# Non-last PP rank.
intermediate_tensors = model_output
assert isinstance(intermediate_tensors, IntermediateTensors)
if self.intermediate_tensors is None:
self.intermediate_tensors = IntermediateTensors(
{
k: torch.empty_like(v)
for k, v in intermediate_tensors.tensors.items()
}
)
for k, v in intermediate_tensors.tensors.items():
self.intermediate_tensors[k][:num_tokens] = v
return forward_fn
@@ -354,9 +389,13 @@ class ModelCudaGraphManager(CudaGraphManager):
def run_fullgraph(
self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
"""Replay a captured FULL cudagraph and return hidden states."""
super().run_fullgraph(desc)
if not self.is_last_pp_rank:
assert self.intermediate_tensors is not None
return self.intermediate_tensors[: desc.num_tokens]
assert self.hidden_states is not None
hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs:

View File

@@ -140,6 +140,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
@@ -301,6 +303,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)
if not self.is_first_pp_rank:
# For non-first PP ranks, create intermediate tensors sized
# for the max capture size so they can be sliced per batch.
# Save as persistent member so runtime can copy received data
# into the same addresses that the CUDA graphs captured.
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
def get_model(self) -> nn.Module:
return self.model
@@ -396,14 +409,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True)
# For non-first PP ranks, create dummy intermediate_tensors.
# Get the intermediate tensors for the dummy run.
intermediate_tensors = None
if not self.is_first_pp_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
assert self.intermediate_tensors is not None
intermediate_tensors = self.intermediate_tensors[:num_tokens]
# Execute the model.
self.execute_model(
@@ -528,14 +538,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return 0
# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0
start_time = time.perf_counter()
gc.collect()
torch.accelerator.empty_cache()
@@ -546,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model,
self.model_state,
self.input_buffers,
self.intermediate_tensors,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
@@ -1010,7 +1013,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
"intermediate_tensors": intermediate_tensors,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
@@ -1019,7 +1021,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
# Prepare the intermediate tensors.
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
intermediate_tensors = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
},
intermediate_tensors.kv_connector_output,
)
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL:
@@ -1028,11 +1042,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
batch_descriptor = BatchDescriptor(
@@ -1052,11 +1061,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
if self.is_last_pp_rank:
if self.use_aux_hidden_state_outputs:
assert isinstance(model_output, tuple)
hidden_states, aux_hidden_states = model_output
else:
assert isinstance(model_output, torch.Tensor)
hidden_states = model_output
aux_hidden_states = None
output_intermediate_tensors = None
else:
assert isinstance(model_output, IntermediateTensors)
hidden_states = None
aux_hidden_states = None
output_intermediate_tensors = model_output
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ExecuteModelState(
@@ -1071,11 +1090,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
assert output_intermediate_tensors is not None
output_intermediate_tensors.kv_connector_output = kv_connector_output
return output_intermediate_tensors
return None
@torch.inference_mode()
@@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple):
input_batch: InputBatch
attn_metadata: dict[str, Any] | None
slot_mappings_by_layer: dict[str, torch.Tensor] | None
hidden_states: torch.Tensor | IntermediateTensors
hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None