[ROCm]: gpt-oss fusion/padding fixes (#38043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Rohan Potdar
2026-03-27 11:19:15 -05:00
committed by GitHub
parent 21d2b53f88
commit 0e9358c11d
3 changed files with 4 additions and 19 deletions

View File

@@ -152,13 +152,11 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss."""
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
return ( return (
rocm_aiter_ops.is_rmsnorm_enabled() rocm_aiter_ops.is_rmsnorm_enabled()
and not rocm_aiter_ops.is_triton_gemm_enabled()
and cfg.model_config is not None and cfg.model_config is not None
and cfg.model_config.get_hidden_size() == 2880 and cfg.model_config.get_hidden_size() == 2880
) )

View File

@@ -20,10 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
mxfp4_w4a16_moe_quant_config, mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config, ocp_mx_moe_quant_config,
) )
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kMxfp4Static, kMxfp4Static,
@@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
intermediate_size = round_up(intermediate_size, 128) intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128) hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
pad_align = get_padding_alignment() intermediate_size = round_up(intermediate_size, 256)
intermediate_size = round_up(intermediate_size, pad_align) hidden_size = round_up(hidden_size, 256)
hidden_size = round_up(hidden_size, pad_align)
else: else:
intermediate_size = round_up(intermediate_size, 64) intermediate_size = round_up(intermediate_size, 64)
return hidden_size, intermediate_size return hidden_size, intermediate_size

View File

@@ -6,7 +6,6 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -85,14 +84,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
return quant_tensor, InFlexData(), scale return quant_tensor, InFlexData(), scale
def get_padding_alignment():
return (
256
if triton.runtime.driver.active.get_current_target().arch in ("gfx950",)
else 128
)
def _dequant_mxfp4( def _dequant_mxfp4(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor: ) -> torch.Tensor: