From 0e9358c11daf3f5a2d4e8f80a100b6d5e070e1a1 Mon Sep 17 00:00:00 2001
From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Date: Fri, 27 Mar 2026 11:19:15 -0500
Subject: [PATCH] [ROCm]: gpt-oss fusion/padding fixes (#38043)

Signed-off-by: Andreas Karatzas
Signed-off-by: Rohan138
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: Andreas Karatzas
---
 vllm/config/vllm.py                                  |  4 +---
 vllm/model_executor/layers/fused_moe/oracle/mxfp4.py | 10 +++-------
 .../layers/quantization/utils/mxfp4_utils.py         |  9 ---------
 3 files changed, 4 insertions(+), 19 deletions(-)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 382b66a70..b6be7f10b 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -152,13 +152,11 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
 
 
 def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
-    """Enable if using AITER RMSNorm and AITER Triton GEMMs
-    and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
+    """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss."""
     from vllm._aiter_ops import rocm_aiter_ops
 
     return (
         rocm_aiter_ops.is_rmsnorm_enabled()
-        and not rocm_aiter_ops.is_triton_gemm_enabled()
         and cfg.model_config is not None
         and cfg.model_config.get_hidden_size() == 2880
     )
diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
index 77df6edf9..9008bdeec 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -20,10 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    _swizzle_mxfp4,
-    get_padding_alignment,
-)
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kMxfp4Static,
@@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
         intermediate_size = round_up(intermediate_size, 128)
         hidden_size = round_up(hidden_size, 128)
     elif current_platform.is_rocm():
-        pad_align = get_padding_alignment()
-        intermediate_size = round_up(intermediate_size, pad_align)
-        hidden_size = round_up(hidden_size, pad_align)
+        intermediate_size = round_up(intermediate_size, 256)
+        hidden_size = round_up(hidden_size, 256)
     else:
         intermediate_size = round_up(intermediate_size, 64)
     return hidden_size, intermediate_size
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 49ddc8acc..40b26528b 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -6,7 +6,6 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.triton_utils import triton
 from vllm.utils.import_utils import has_triton_kernels
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 
@@ -85,14 +84,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
     return quant_tensor, InFlexData(), scale
 
 
-def get_padding_alignment():
-    return (
-        256
-        if triton.runtime.driver.active.get_current_target().arch in ("gfx950",)
-        else 128
-    )
-
-
 def _dequant_mxfp4(
     x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
 ) -> torch.Tensor: