[Perf] Enable cuda graph for deepepHT, 5.3% throughput improvement, 4.4% TTFT improvement (#29558)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -232,44 +232,6 @@ class CudaPlatformBase(Platform):
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse backend."
|
||||
)
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
if parallel_config.decode_context_parallel_size > 1:
|
||||
logger.warning_once(
|
||||
"Decode context parallel (DCP) is enabled, which is "
|
||||
"incompatible with full CUDA graphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
# prefill context parallel do not support full cudagraphs
|
||||
elif parallel_config.prefill_context_parallel_size > 1:
|
||||
logger.warning_once(
|
||||
"Prefill context parallel (PCP) is enabled, which is "
|
||||
"incompatible with full CUDA graphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
if (
|
||||
parallel_config.all2all_backend == "deepep_high_throughput"
|
||||
and parallel_config.data_parallel_size > 1
|
||||
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
):
|
||||
# TODO: Piecewise Cuda graph might be enabled
|
||||
# if torch compile cache key issue fixed
|
||||
# See https://github.com/vllm-project/vllm/pull/25093
|
||||
logger.info(
|
||||
"WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
|
||||
"kernels are optimized for prefill and are incompatible with "
|
||||
"CUDA Graphs. "
|
||||
"In order to use CUDA Graphs for decode-optimized workloads, "
|
||||
"use --all2all-backend with another option, such as "
|
||||
"deepep_low_latency, pplx, or allgather_reducescatter."
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(
|
||||
|
||||
Reference in New Issue
Block a user