[ROCm] Add extra step in config initialization to populate custom ops before compilation config init (#34848)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
9f9a675b23
commit
6042e66cd5
@@ -809,6 +809,8 @@ class VllmConfig:
|
||||
if "-quant_fp8" not in custom_ops:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
current_platform.apply_config_platform_defaults(self)
|
||||
|
||||
if self.compilation_config.mode is None:
|
||||
if self.optimization_level > OptimizationLevel.O0:
|
||||
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
|
||||
|
||||
@@ -393,6 +393,20 @@ class Platform:
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
    """
    Hook for platform-specific adjustments to the config's default values.

    Invoked while the global VllmConfig is being initialized, once the CLI
    arguments have been parsed. A platform may override this to change
    defaults in a way that suits it, e.g. turning on entries in custom_ops
    depending on which features are enabled.

    ``vllm_config`` is passed by reference and is meant to be mutated in
    place; nothing is returned. This base implementation changes nothing.
    """
    return None
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
|
||||
@@ -482,19 +482,61 @@ class RocmPlatform(Platform):
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
    """
    Populate ROCm-specific default custom ops on the compilation config.

    Queries ``rocm_aiter_ops`` for the enabled AITER features and appends the
    matching ``+<op>`` entries to ``vllm_config.compilation_config.custom_ops``:

    * ``+rms_norm`` when AITER rms norm is enabled, the cudagraph mode is not
      ``CUDAGraphMode.NONE`` and ``-rms_norm`` is absent.
    * ``+quant_fp8`` when AITER fp8 linear is enabled and ``-quant_fp8`` is
      absent.
    * ``+grouped_topk`` when AITER fused MoE is enabled and neither
      ``+grouped_topk`` nor ``-grouped_topk`` is present. If AITER shared
      expert fusion is enabled, a ``-grouped_topk`` entry is removed first
      (with a warning).
    * ``+rotary_embedding`` when the AITER triton rotary embedding is enabled
      and neither ``+rotary_embedding`` nor ``-rotary_embedding`` is present.
    * ``+sparse_attn_indexer`` unconditionally.

    The config is modified in place; nothing is returned.
    """
    from vllm._aiter_ops import rocm_aiter_ops
    from vllm.config.compilation import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Compare the cudagraph mode (not the config object itself) against NONE.
    is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
    use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
    use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
    use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
    use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()

    # Aiter rms norm performs best when CUDA Graph capture is enabled.
    if (
        use_aiter_rms_norm
        and not is_eager_execution
        and "-rms_norm" not in compilation_config.custom_ops
    ):
        compilation_config.custom_ops.append("+rms_norm")

    if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
        compilation_config.custom_ops.append("+quant_fp8")

    if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
        logger.warning_once(
            "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
            "requires the 'grouped_topk' custom op. Overriding the "
            "user-provided '-grouped_topk'."
        )
        compilation_config.custom_ops.remove("-grouped_topk")

    # Ensure grouped_topk is always enabled when using AITER if
    # it's not disabled by the user.
    if (
        use_aiter_fused_moe
        and "+grouped_topk" not in compilation_config.custom_ops
        and "-grouped_topk" not in compilation_config.custom_ops
    ):
        compilation_config.custom_ops.append("+grouped_topk")

    # Enable rotary embedding when using AITER if it's not disabled by the user.
    if (
        use_aiter_triton_rope
        and "+rotary_embedding" not in compilation_config.custom_ops
        and "-rotary_embedding" not in compilation_config.custom_ops
    ):
        compilation_config.custom_ops.append("+rotary_embedding")

    # Default dispatch to rocm's sparse_attn_indexer implementation.
    compilation_config.custom_ops.append("+sparse_attn_indexer")
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
if compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
# decode context parallel does not support full cudagraphs
|
||||
@@ -533,42 +575,6 @@ class RocmPlatform(Platform):
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
# Aiter rms norm perform best when CUDA Graph capture is enabled.
|
||||
if (
|
||||
use_aiter_rms_norm
|
||||
and not is_eager_execution
|
||||
and "-rms_norm" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
|
||||
compilation_config.custom_ops.append("+quant_fp8")
|
||||
|
||||
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
|
||||
logger.warning_once(
|
||||
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
|
||||
"requires the 'grouped_topk' custom op. Overriding the "
|
||||
"user-provided '-grouped_topk'."
|
||||
)
|
||||
compilation_config.custom_ops.remove("-grouped_topk")
|
||||
# Ensure grouped_topk is always enabled when using AITER if
|
||||
# its not disabled by user
|
||||
if (
|
||||
use_aiter_fused_moe
|
||||
and "+grouped_topk" not in compilation_config.custom_ops
|
||||
and "-grouped_topk" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+grouped_topk")
|
||||
# Enable rotary embedding when using AITER if its not disabled by user
|
||||
if (
|
||||
use_aiter_triton_rope
|
||||
and "+rotary_embedding" not in compilation_config.custom_ops
|
||||
and "-rotary_embedding" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rotary_embedding")
|
||||
|
||||
# Default dispatch to rocm's sparse_attn_indexer implementation
|
||||
compilation_config.custom_ops.append("+sparse_attn_indexer")
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user