[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,7 @@ from vllm import version
|
||||
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
PrefixCachingHashAlgo)
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
PassConfig)
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.utils import ConfigType, config
|
||||
@@ -3529,11 +3529,21 @@ class VllmConfig:
|
||||
else:
|
||||
self.compilation_config.level = \
|
||||
CompilationLevel.NO_COMPILATION
|
||||
|
||||
else:
|
||||
# NB: Passing both --enforce-eager and a compilation level
|
||||
# in V0 means the compilation level wins out.
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
# if cudagraph_mode is not explicitly set by the user, set a default value
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||
== CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
@@ -3541,12 +3551,13 @@ class VllmConfig:
|
||||
True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||
not self.model_config.enforce_eager:
|
||||
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
|
||||
# is set to True, full CUDA graphs will be used.
|
||||
|
||||
# disable cudagraph when eager execution is enforced
|
||||
if self.model_config is not None and self.model_config.enforce_eager:
|
||||
logger.info("Cudagraph is disabled under eager mode")
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif envs.VLLM_USE_V1:
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
@@ -3566,12 +3577,6 @@ class VllmConfig:
|
||||
"Disabling `torch.compile`.")
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if self.compilation_config.full_cuda_graph and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("full_cuda_graph is not supported with "
|
||||
"cascade attention. Disabling cascade attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
disable_chunked_prefill_reasons: list[str] = []
|
||||
|
||||
if self.model_config and self.model_config.pooler_config:
|
||||
@@ -3612,9 +3617,32 @@ class VllmConfig:
|
||||
"to True to enable.")
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# final check of cudagraph mode after platform-specific update
|
||||
if envs.VLLM_USE_V1:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
|
||||
and self.model_config is not None and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("CUDAGraphMode.FULL is not supported with "
|
||||
"cascade attention currently. Disabling cascade"
|
||||
"attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.compilation_config.cudagraph_mode\
|
||||
.requires_piecewise_compilation():
|
||||
assert self.compilation_config.level == \
|
||||
CompilationLevel.PIECEWISE, \
|
||||
"Compilation level should be CompilationLevel.PIECEWISE "\
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# logger should only print warning message for hybrid models. As we
|
||||
|
||||
Reference in New Issue
Block a user