[Core] Support full cuda graph in v1 (#16072)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Author: Chanh Nguyen
Date: 2025-05-07 22:30:15 -07:00
Committed by: GitHub
parent 3d13ca0e24
commit 7ea2adb802
5 changed files with 190 additions and 13 deletions


@@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Because the whole graph is captured at once, this flag cannot be used together
with splitting_ops. This may provide performance benefits for smaller models.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
@@ -3649,6 +3653,7 @@ class CompilationConfig(BaseModel):
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[list[int]] = None
cudagraph_copy_inputs: bool = False
full_cuda_graph: bool = False
class PassConfig(BaseModel):
"""
@@ -3871,10 +3876,14 @@ class CompilationConfig(BaseModel):
self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# If default, override splitting ops for piecewise cudagraph on V1.
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")
if not self.splitting_ops:
-self.splitting_ops = [
+self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
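
To see the resolution logic of set_splitting_ops_for_v1 in isolation, here is a minimal standalone sketch (a hypothetical toy class, not vLLM's actual CompilationConfig) showing that full_cuda_graph leaves splitting_ops empty, so attention is captured inside the one graph instead of being split into subgraphs:

class _CompilationConfigSketch:
    """Toy reproduction of the splitting_ops resolution rule above."""

    def __init__(self, splitting_ops=None, full_cuda_graph=False):
        self.splitting_ops = splitting_ops
        self.full_cuda_graph = full_cuda_graph

    def set_splitting_ops_for_v1(self):
        # Explicit splitting_ops and full_cuda_graph are mutually exclusive.
        if self.splitting_ops and self.full_cuda_graph:
            raise ValueError("full_cuda_graph cannot be used with splitting_ops")
        if not self.splitting_ops:
            self.splitting_ops = [] if self.full_cuda_graph else [
                "vllm.unified_attention",
                "vllm.unified_attention_with_output",
            ]

cfg = _CompilationConfigSketch(full_cuda_graph=True)
cfg.set_splitting_ops_for_v1()
assert cfg.splitting_ops == []  # nothing is split out; one full CUDA graph

cfg = _CompilationConfigSketch()
cfg.set_splitting_ops_for_v1()
assert cfg.splitting_ops == ["vllm.unified_attention",
                             "vllm.unified_attention_with_output"]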
@@ -4151,6 +4160,12 @@ class VllmConfig:
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):
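
The VllmConfig check added above enforces a simple precedence: full_cuda_graph wins over cascade attention. A simplified standalone sketch of that guard (hypothetical function name, duck-typed config objects, and plain logging in place of vLLM's logger.warning_once):

import logging

logger = logging.getLogger(__name__)

def reconcile_cascade_attn(compilation_config, model_config) -> None:
    """Mirror of the guard above: full_cuda_graph force-disables cascade attention."""
    if compilation_config.full_cuda_graph and not model_config.disable_cascade_attn:
        logger.warning("full_cuda_graph is not supported with cascade attention. "
                       "Disabling cascade attention.")
        model_config.disable_cascade_attn = True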