[Model Runner V2] Enable piecewise CUDA graphs for pipeline parallelism (#35162)
Signed-off-by: Zhanqiu Hu <zh338@cornell.edu> Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -11,11 +11,16 @@ from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group,
|
||||
graph_capture,
|
||||
is_global_first_rank,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
@@ -87,7 +92,15 @@ class CudaGraphManager:
|
||||
assert self.compilation_config is not None
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
self.decode_query_len = decode_query_len
|
||||
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
if self.pp_size > 1:
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
|
||||
@@ -267,12 +280,14 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
self.intermediate_tensors: IntermediateTensors | None = None
|
||||
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -293,6 +308,19 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
if self.dp_size > 1
|
||||
else None
|
||||
)
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
if not self.is_first_pp_rank:
|
||||
# Update for non-first PP ranks.
|
||||
model_inputs["input_ids"] = None
|
||||
model_inputs["inputs_embeds"] = None
|
||||
assert intermediate_tensors is not None
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors[:num_tokens]
|
||||
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
@@ -318,21 +346,15 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
slot_mapping=slot_mappings,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
# TODO: Pass intermediate_tensors for PP CUDA graph
|
||||
# support (https://github.com/vllm-project/vllm/pull/35162).
|
||||
"intermediate_tensors": None,
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
if cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
# PW CUDA graph internally handles the model outputs.
|
||||
# No need to keep track of the hidden states.
|
||||
return None
|
||||
if cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
# PW CUDA graph internally handles the model outputs.
|
||||
# No need to keep track of the hidden states.
|
||||
return None
|
||||
|
||||
if self.is_last_pp_rank:
|
||||
# Last PP rank (common case).
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -340,13 +362,26 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
aux_hidden_states = []
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
|
||||
self.aux_hidden_states = [
|
||||
torch.empty_like(x) for x in aux_hidden_states
|
||||
]
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
for i, aux in enumerate(aux_hidden_states):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux
|
||||
else:
|
||||
# Non-last PP rank.
|
||||
intermediate_tensors = model_output
|
||||
assert isinstance(intermediate_tensors, IntermediateTensors)
|
||||
if self.intermediate_tensors is None:
|
||||
self.intermediate_tensors = IntermediateTensors(
|
||||
{
|
||||
k: torch.empty_like(v)
|
||||
for k, v in intermediate_tensors.tensors.items()
|
||||
}
|
||||
)
|
||||
for k, v in intermediate_tensors.tensors.items():
|
||||
self.intermediate_tensors[k][:num_tokens] = v
|
||||
|
||||
return forward_fn
|
||||
|
||||
@@ -354,9 +389,13 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
|
||||
def run_fullgraph(
|
||||
self, desc: BatchExecutionDescriptor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
|
||||
"""Replay a captured FULL cudagraph and return hidden states."""
|
||||
super().run_fullgraph(desc)
|
||||
if not self.is_last_pp_rank:
|
||||
assert self.intermediate_tensors is not None
|
||||
return self.intermediate_tensors[: desc.num_tokens]
|
||||
|
||||
assert self.hidden_states is not None
|
||||
hidden_states = self.hidden_states[: desc.num_tokens]
|
||||
if not self.use_aux_hidden_state_outputs:
|
||||
|
||||
@@ -140,6 +140,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
# Persistent buffer for intermediate tensors (non-first PP ranks).
|
||||
self.intermediate_tensors: IntermediateTensors | None = None
|
||||
|
||||
# Data parallelism.
|
||||
self.dp_size = self.parallel_config.data_parallel_size
|
||||
@@ -301,6 +303,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_pooling_model and self.is_last_pp_rank:
|
||||
self.pooling_runner = PoolingRunner(self.model)
|
||||
|
||||
if not self.is_first_pp_rank:
|
||||
# For non-first PP ranks, create intermediate tensors sized
|
||||
# for the max capture size so they can be sliced per batch.
|
||||
# Save as persistent member so runtime can copy received data
|
||||
# into the same addresses that the CUDA graphs captured.
|
||||
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=self.max_num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@@ -396,14 +409,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Disable any use of KVConnector for dummy runs.
|
||||
self.kv_connector.set_disabled(True)
|
||||
|
||||
# For non-first PP ranks, create dummy intermediate_tensors.
|
||||
# Get the intermediate tensors for the dummy run.
|
||||
intermediate_tensors = None
|
||||
if not self.is_first_pp_rank:
|
||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
assert self.intermediate_tensors is not None
|
||||
intermediate_tensors = self.intermediate_tensors[:num_tokens]
|
||||
|
||||
# Execute the model.
|
||||
self.execute_model(
|
||||
@@ -528,14 +538,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
return 0
|
||||
|
||||
# TODO (zhanqiu): support CUDA graph for PP.
|
||||
if self.use_pp:
|
||||
logger.warning_once(
|
||||
"Skipping CUDA graph capture because pipeline parallel is "
|
||||
"enabled. Pipeline parallel is currently eager-only.",
|
||||
)
|
||||
return 0
|
||||
|
||||
start_time = time.perf_counter()
|
||||
gc.collect()
|
||||
torch.accelerator.empty_cache()
|
||||
@@ -546,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model,
|
||||
self.model_state,
|
||||
self.input_buffers,
|
||||
self.intermediate_tensors,
|
||||
self.block_tables,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
@@ -1010,7 +1013,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
"input_ids": input_batch.input_ids,
|
||||
"positions": input_batch.positions,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
# NOTE: Values returned by `prepare_inputs` will override the default
|
||||
# values above.
|
||||
**self.model_state.prepare_inputs(input_batch, self.req_states),
|
||||
@@ -1019,7 +1021,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Update for non-first PP ranks.
|
||||
model_inputs["input_ids"] = None
|
||||
model_inputs["inputs_embeds"] = None
|
||||
|
||||
# Prepare the intermediate tensors.
|
||||
assert intermediate_tensors is not None
|
||||
assert self.intermediate_tensors is not None
|
||||
n = input_batch.num_tokens_after_padding
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
{
|
||||
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
|
||||
for k, v in self.intermediate_tensors.tensors.items()
|
||||
},
|
||||
intermediate_tensors.kv_connector_output,
|
||||
)
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors
|
||||
|
||||
# Run model.
|
||||
if batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
@@ -1028,11 +1042,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
else:
|
||||
# For piecewise and eager mode, just call model().
|
||||
batch_descriptor = BatchDescriptor(
|
||||
@@ -1052,11 +1061,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
model_output = self.model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
if self.is_last_pp_rank:
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
assert isinstance(model_output, tuple)
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
assert isinstance(model_output, torch.Tensor)
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
output_intermediate_tensors = None
|
||||
else:
|
||||
assert isinstance(model_output, IntermediateTensors)
|
||||
hidden_states = None
|
||||
aux_hidden_states = None
|
||||
output_intermediate_tensors = model_output
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
self.execute_model_state = ExecuteModelState(
|
||||
@@ -1071,11 +1090,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: return IntermediateTensors for sending.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
return hidden_states
|
||||
# Last rank (or no PP): hidden_states is a tensor for sampling.
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
assert output_intermediate_tensors is not None
|
||||
output_intermediate_tensors.kv_connector_output = kv_connector_output
|
||||
return output_intermediate_tensors
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple):
|
||||
input_batch: InputBatch
|
||||
attn_metadata: dict[str, Any] | None
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] | None
|
||||
hidden_states: torch.Tensor | IntermediateTensors
|
||||
hidden_states: torch.Tensor | None
|
||||
aux_hidden_states: list[torch.Tensor] | None
|
||||
kv_connector_output: KVConnectorOutput | None
|
||||
num_tokens_across_dp: torch.Tensor | None
|
||||
|
||||
Reference in New Issue
Block a user