[V1] [Hybrid] Enable Full CUDA graph by default for hybrid models in V1 (#22594)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -4,6 +4,7 @@ from copy import deepcopy
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.config.compilation import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -275,6 +276,42 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
                 "%d for performance.", 1024)
 
 
+class MambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
+        to get good performance for mamba layers in V1).
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        model_config = vllm_config.model_config
+        compilation_config = vllm_config.compilation_config
+
+        model_cls, _ = ModelRegistry.resolve_model_cls(
+            model_config.architecture,
+            model_config=model_config,
+        )
+
+        # TODO(tdoublep): remove as full cuda graph support is added
+        FCG_NOT_SUPPORTED_MODELS = [
+            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+        ]
+
+        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
+                and compilation_config.cudagraph_mode is None):
+            logger.info(
+                "Hybrid or mamba-based model detected: setting cudagraph mode "
+                "to FULL_AND_PIECEWISE in order to optimize performance.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
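The key property of the new hook is that it only fills in a default: an explicitly configured cudagraph_mode is never overridden, and architectures on the exclusion list are left alone. Below is a minimal, self-contained sketch of that selection logic; the CudaGraphMode enum, CompilationConfig dataclass, and pick_default_mode function are simplified stand-ins for illustration, not vLLM's actual API.

import enum
from dataclasses import dataclass
from typing import Optional


class CudaGraphMode(enum.Enum):
    NONE = enum.auto()
    PIECEWISE = enum.auto()
    FULL_AND_PIECEWISE = enum.auto()


@dataclass
class CompilationConfig:
    # None means "user expressed no preference"; only then may the hook
    # pick a default.
    cudagraph_mode: Optional[CudaGraphMode] = None


# Mirrors FCG_NOT_SUPPORTED_MODELS in the diff: architectures that do not
# yet support full CUDA graphs.
FCG_NOT_SUPPORTED = {"Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"}


def pick_default_mode(architecture: str, cfg: CompilationConfig) -> None:
    """Set FULL_AND_PIECEWISE only if the mode is unset and the
    architecture is not excluded."""
    if cfg.cudagraph_mode is None and architecture not in FCG_NOT_SUPPORTED:
        cfg.cudagraph_mode = CudaGraphMode.FULL_AND_PIECEWISE


cfg = CompilationConfig()
pick_default_mode("Mamba2ForCausalLM", cfg)
assert cfg.cudagraph_mode is CudaGraphMode.FULL_AND_PIECEWISE

# An explicit user choice is never overridden:
cfg2 = CompilationConfig(cudagraph_mode=CudaGraphMode.PIECEWISE)
pick_default_mode("Mamba2ForCausalLM", cfg2)
assert cfg2.cudagraph_mode is CudaGraphMode.PIECEWISE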
@@ -293,6 +330,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if not envs.VLLM_USE_V1:
             return
 
+        # Enable FULL_AND_PIECEWISE by default
+        MambaModelConfig.verify_and_update_config(vllm_config)
+
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
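Note the composition in this hunk: the hybrid-model hook first delegates to MambaModelConfig so the cudagraph default is applied in one place, then continues with its own cache and parallelism checks. A rough sketch of that pattern, with a plain dict standing in for VllmConfig (names below are illustrative, not vLLM's types):

from typing import Any


class MambaModelConfig:
    @classmethod
    def verify_and_update_config(cls, vllm_config: dict[str, Any]) -> None:
        # Fill in the cudagraph default only if the user left it unset.
        vllm_config.setdefault("cudagraph_mode", "FULL_AND_PIECEWISE")


class HybridAttentionMambaModelConfig:
    @classmethod
    def verify_and_update_config(cls, vllm_config: dict[str, Any]) -> None:
        # Reuse the mamba defaults first...
        MambaModelConfig.verify_and_update_config(vllm_config)
        # ...then hybrid-specific cache/page-size updates would follow
        # (elided in this sketch).


cfg: dict[str, Any] = {}
HybridAttentionMambaModelConfig.verify_and_update_config(cfg)
assert cfg["cudagraph_mode"] == "FULL_AND_PIECEWISE"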
@@ -374,4 +414,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
+    "MambaForCausalLM": MambaModelConfig,
+    "Mamba2ForCausalLM": MambaModelConfig,
 }
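With these two new entries, pure mamba models are routed through the same verify-and-update hook as hybrid models. For illustration, here is a hypothetical sketch of how an architecture-to-hook map like MODELS_CONFIG_MAP can be consulted at startup; the dispatch function and dict values are simplified stand-ins, not vLLM's actual dispatch code.

from typing import Any, Callable

# Simplified stand-in: architecture name -> config update hook.
ConfigHook = Callable[[dict[str, Any]], None]


def _mamba_defaults(cfg: dict[str, Any]) -> None:
    cfg.setdefault("cudagraph_mode", "FULL_AND_PIECEWISE")


MODELS_CONFIG_MAP: dict[str, ConfigHook] = {
    "MambaForCausalLM": _mamba_defaults,
    "Mamba2ForCausalLM": _mamba_defaults,
}


def apply_arch_overrides(architecture: str, cfg: dict[str, Any]) -> None:
    # Only architectures with a registered hook get extra verification.
    hook = MODELS_CONFIG_MAP.get(architecture)
    if hook is not None:
        hook(cfg)


cfg: dict[str, Any] = {}
apply_arch_overrides("Mamba2ForCausalLM", cfg)
assert cfg["cudagraph_mode"] == "FULL_AND_PIECEWISE"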