[Core] Support full cuda graph in v1 (#16072)
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
@@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
         are always used, it can set this to False. Otherwise, it should
         set this to True, and the compiler will copy the input to an
         internally managed buffer. Default is False.
+    - full_cuda_graph: whether to use a full cuda graph for the entire forward
+        pass rather than splitting certain operations such as attention into subgraphs.
+        Thus this flag cannot be used together with splitting_ops. This may provide
+        performance benefits for smaller models.
     - Inductor compilation:
         - use_inductor: whether to use inductor compilation.
             - False: inductor compilation is not used. graph runs in eager.
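For context beyond the diff, here is a minimal sketch of switching the new flag on from the offline entrypoint. It assumes the vllm.LLM constructor's compilation_config argument accepts a mapping of CompilationConfig fields (as in recent vLLM releases); the model name is only a placeholder.

from vllm import LLM

# Hypothetical usage: capture the entire forward pass in one CUDA graph.
# The model name is a placeholder; passing compilation_config as a dict
# kwarg is assumed from vLLM's public API.
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config={"full_cuda_graph": True},
)
outputs = llm.generate("Hello, my name is")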
@@ -3649,6 +3653,7 @@ class CompilationConfig(BaseModel):
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[list[int]] = None
     cudagraph_copy_inputs: bool = False
+    full_cuda_graph: bool = False
 
     class PassConfig(BaseModel):
         """
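A quick sanity check of the new field's default; the vllm.config import path is an assumption, inferred from the class shown in this diff.

from vllm.config import CompilationConfig

# CompilationConfig is a pydantic BaseModel (see the class signature above),
# so fields can be set by keyword; full_cuda_graph stays off by default.
assert CompilationConfig().full_cuda_graph is False
assert CompilationConfig(full_cuda_graph=True).full_cuda_graph is True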
@@ -3871,10 +3876,14 @@ class CompilationConfig(BaseModel):
             self.max_capture_size] = self.max_capture_size
 
     def set_splitting_ops_for_v1(self):
         # If default, override splitting ops for piecewise cudagraph on V1.
         # NOTE: this function needs to be called
+        if self.splitting_ops and self.full_cuda_graph:
+            raise ValueError("full_cuda_graph cannot be used together with "
+                             "splitting_ops, as Full CUDA graph will override "
+                             f"the splitting_ops: {self.splitting_ops}")
+
         if not self.splitting_ops:
-            self.splitting_ops = [
+            self.splitting_ops = [] if self.full_cuda_graph else [
                 "vllm.unified_attention",
                 "vllm.unified_attention_with_output",
             ]
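The behavior added in this hunk can be exercised directly. This sketch relies only on what the hunk itself shows (the field names, the method name, and the mutual-exclusion rule); the vllm.config import path is assumed.

from vllm.config import CompilationConfig

# Passing explicit splitting_ops together with full_cuda_graph is rejected.
cfg = CompilationConfig(full_cuda_graph=True,
                        splitting_ops=["vllm.unified_attention"])
try:
    cfg.set_splitting_ops_for_v1()
except ValueError as err:
    print(err)  # full_cuda_graph cannot be used together with splitting_ops...

# With no splitting_ops given, full_cuda_graph resolves to an empty list,
# so the whole forward pass is compiled and captured as a single graph.
cfg = CompilationConfig(full_cuda_graph=True)
cfg.set_splitting_ops_for_v1()
assert cfg.splitting_ops == []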
@@ -4151,6 +4160,12 @@ class VllmConfig:
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.compilation_config.full_cuda_graph and \
+            not self.model_config.disable_cascade_attn:
+            logger.warning_once(
+                "full_cuda_graph is not supported with "
+                "cascade attention. Disabling cascade attention.")
+            self.model_config.disable_cascade_attn = True
+
         if self.model_config and self.model_config.use_mla and \
             not (current_platform.is_cuda() or current_platform.is_rocm()):
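To make the precedence rule in this hunk easy to see in isolation: enabling full_cuda_graph force-disables cascade attention with a warning rather than an error. Below is a self-contained toy reconstruction; the dataclasses and the reconcile helper are illustrative stand-ins, not vLLM code.

from dataclasses import dataclass

@dataclass
class ToyModelConfig:
    disable_cascade_attn: bool = False

@dataclass
class ToyCompilationConfig:
    full_cuda_graph: bool = False

def reconcile(model: ToyModelConfig, comp: ToyCompilationConfig) -> None:
    # Mirrors the config logic above: full_cuda_graph takes precedence,
    # and cascade attention is switched off with a warning.
    if comp.full_cuda_graph and not model.disable_cascade_attn:
        print("WARNING: full_cuda_graph is not supported with cascade "
              "attention. Disabling cascade attention.")
        model.disable_cascade_attn = True

m = ToyModelConfig()
reconcile(m, ToyCompilationConfig(full_cuda_graph=True))
assert m.disable_cascade_attn is True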