diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index fba3c64a9..127c16ac7 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -809,6 +809,8 @@ class VllmConfig:
             if "-quant_fp8" not in custom_ops:
                 custom_ops.append("+quant_fp8")
 
+        current_platform.apply_config_platform_defaults(self)
+
         if self.compilation_config.mode is None:
             if self.optimization_level > OptimizationLevel.O0:
                 self.compilation_config.mode = CompilationMode.VLLM_COMPILE
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 75e716479..5dae76757 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -393,6 +393,20 @@ class Platform:
         """
         pass
 
+    @classmethod
+    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Apply platform-specific default values to the config.
+
+        This hook is called while the global VllmConfig is being initialized,
+        after CLI arguments have been parsed. It may adjust the config's
+        defaults for the platform; for example, it can enable custom ops
+        based on the features that are enabled.
+
+        The config is passed by reference, so it can be modified in place.
+        """
+        pass
+
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index e1e2ffb1d..3808ecc6e 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -482,19 +482,61 @@ class RocmPlatform(Platform):
         return device_props.total_memory
 
     @classmethod
-    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
         from vllm._aiter_ops import rocm_aiter_ops
         from vllm.config.compilation import CUDAGraphMode
 
-        cache_config = vllm_config.cache_config
         compilation_config = vllm_config.compilation_config
-        parallel_config = vllm_config.parallel_config
-        is_eager_execution = compilation_config == CUDAGraphMode.NONE
+        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
         use_aiter_fused_moe = rocm_aiter_ops.is_fused_moe_enabled()
         use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
         use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled()
         use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
         use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled()
+        # AITER RMS norm performs best when CUDA Graph capture is enabled.
+        if (
+            use_aiter_rms_norm
+            and not is_eager_execution
+            and "-rms_norm" not in compilation_config.custom_ops
+        ):
+            compilation_config.custom_ops.append("+rms_norm")
+
+        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
+            compilation_config.custom_ops.append("+quant_fp8")
+
+        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
+            logger.warning_once(
+                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
+                "requires the 'grouped_topk' custom op. Overriding the "
+                "user-provided '-grouped_topk'."
+            )
+            compilation_config.custom_ops.remove("-grouped_topk")
+        # Ensure grouped_topk is always enabled when using AITER if it is
+        # not disabled by the user.
+        if (
+            use_aiter_fused_moe
+            and "+grouped_topk" not in compilation_config.custom_ops
+            and "-grouped_topk" not in compilation_config.custom_ops
+        ):
+            compilation_config.custom_ops.append("+grouped_topk")
+        # Enable rotary embedding when using AITER if it is not disabled by the user.
+        if (
+            use_aiter_triton_rope
+            and "+rotary_embedding" not in compilation_config.custom_ops
+            and "-rotary_embedding" not in compilation_config.custom_ops
+        ):
+            compilation_config.custom_ops.append("+rotary_embedding")
+
+        # Default dispatch to ROCm's sparse_attn_indexer implementation.
+        compilation_config.custom_ops.append("+sparse_attn_indexer")
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm.config.compilation import CUDAGraphMode
+
+        cache_config = vllm_config.cache_config
+        compilation_config = vllm_config.compilation_config
+        parallel_config = vllm_config.parallel_config
 
         if compilation_config.cudagraph_mode.has_full_cudagraphs():
             # decode context parallel does not support full cudagraphs
@@ -533,42 +575,5 @@
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
 
-        # Aiter rms norm perform best when CUDA Graph capture is enabled.
-        if (
-            use_aiter_rms_norm
-            and not is_eager_execution
-            and "-rms_norm" not in compilation_config.custom_ops
-        ):
-            compilation_config.custom_ops.append("+rms_norm")
-
-        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
-            compilation_config.custom_ops.append("+quant_fp8")
-
-        if use_aiter_fused_se and "-grouped_topk" in compilation_config.custom_ops:
-            logger.warning_once(
-                "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which "
-                "requires the 'grouped_topk' custom op. Overriding the "
-                "user-provided '-grouped_topk'."
-            )
-            compilation_config.custom_ops.remove("-grouped_topk")
-        # Ensure grouped_topk is always enabled when using AITER if
-        # its not disabled by user
-        if (
-            use_aiter_fused_moe
-            and "+grouped_topk" not in compilation_config.custom_ops
-            and "-grouped_topk" not in compilation_config.custom_ops
-        ):
-            compilation_config.custom_ops.append("+grouped_topk")
-        # Enable rotary embedding when using AITER if its not disabled by user
-        if (
-            use_aiter_triton_rope
-            and "+rotary_embedding" not in compilation_config.custom_ops
-            and "-rotary_embedding" not in compilation_config.custom_ops
-        ):
-            compilation_config.custom_ops.append("+rotary_embedding")
-
-        # Default dispatch to rocm's sparse_attn_indexer implementation
-        compilation_config.custom_ops.append("+sparse_attn_indexer")
-
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
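For downstream platform authors, here is a minimal sketch of how the new extension point can be used. `MyPlatform` and the choice of `rms_norm` are illustrative assumptions, not part of this PR; `Platform` and `VllmConfig` are the real classes touched by this diff.

```python
from typing import TYPE_CHECKING

from vllm.platforms.interface import Platform

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class MyPlatform(Platform):  # hypothetical out-of-tree platform
    @classmethod
    def apply_config_platform_defaults(cls, vllm_config: "VllmConfig") -> None:
        custom_ops = vllm_config.compilation_config.custom_ops
        # Mirror the ROCm pattern above: only default an op on when the user
        # has not already forced it on ("+rms_norm") or off ("-rms_norm"),
        # so explicit user choices always win over platform defaults.
        if "+rms_norm" not in custom_ops and "-rms_norm" not in custom_ops:
            custom_ops.append("+rms_norm")
```

Because `VllmConfig.__post_init__` invokes the hook before `compilation_config.mode` is resolved (see the first hunk), platform defaults land early enough to influence the rest of config post-processing, while a user-supplied `"-rms_norm"` still takes precedence.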