[Compilation][WideEP] Enable Piecewise CUDAGraph for DeepEPHT (#24123)

This commit is contained in:
Wentao Ye
2025-09-09 10:21:10 -04:00
committed by GitHub
parent 6fb2788163
commit a55cf41a09
2 changed files with 21 additions and 10 deletions

View File

@@ -546,7 +546,8 @@ class CompilationConfig:
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
self.splitting_ops = self._attention_ops
# make a copy to avoid mutating the class-level list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
@@ -561,6 +562,18 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
# exclude MoE dispatch/combine from capture by ensuring
# piecewise splitting includes them, so communication remains
# outside CUDA graphs while compute can still be graphed.
moe_ops = [
"vllm.moe_forward",
"vllm.moe_forward_shared",
]
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)