[ROCm] Add extra step in config initialization to populate custom ops before compilation config init (#34848)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2026-02-26 02:05:40 -06:00
committed by GitHub
parent 9f9a675b23
commit 6042e66cd5
3 changed files with 62 additions and 40 deletions

View File

@@ -809,6 +809,8 @@ class VllmConfig:
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
current_platform.apply_config_platform_defaults(self)
if self.compilation_config.mode is None:
if self.optimization_level > OptimizationLevel.O0:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE

View File

@@ -393,6 +393,20 @@ class Platform:
"""
pass
@classmethod
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
"""
Apply the platform-specific default values to the config.
This function is called during the initialization of global VllmConfig, after
parsing cli arguments.
It can modify the defaults of the config according to the platform. For example,
it can enable custom_ops based on the enabled features.
The config is passed by reference, so it can be modified in place.
"""
pass
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""

View File

@@ -482,19 +482,61 @@ class RocmPlatform(Platform):
return device_props.total_memory
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.compilation import CUDAGraphMode
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE
is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if (
use_aiter_rms_norm
and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
logger.warning_once(
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
"requires the 'grouped_topk' custom op. Overriding the "
"user-provided '-grouped_topk'."
)
compilation_config.custom_ops.remove("-grouped_topk")
# Ensure grouped_topk is always enabled when using AITER if
# its not disabled by user
if (
use_aiter_fused_moe
and "+grouped_topk" not in compilation_config.custom_ops
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if its not disabled by user
if (
use_aiter_triton_rope
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.config.compilation import CUDAGraphMode
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
if compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
@@ -533,42 +575,6 @@ class RocmPlatform(Platform):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if (
use_aiter_rms_norm
and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rms_norm")
if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+quant_fp8")
if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
logger.warning_once(
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
"requires the 'grouped_topk' custom op. Overriding the "
"user-provided '-grouped_topk'."
)
compilation_config.custom_ops.remove("-grouped_topk")
# Ensure grouped_topk is always enabled when using AITER if
# its not disabled by user
if (
use_aiter_fused_moe
and "+grouped_topk" not in compilation_config.custom_ops
and "-grouped_topk" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding when using AITER if its not disabled by user
if (
use_aiter_triton_rope
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer")
@classmethod
def verify_model_arch(cls, model_arch: str) -> None: