[ROCm]: gpt-oss fusion/padding fixes (#38043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Rohan Potdar
2026-03-27 11:19:15 -05:00
committed by GitHub
parent 21d2b53f88
commit 0e9358c11d
3 changed files with 4 additions and 19 deletions

View File

@@ -152,13 +152,11 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
"""Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss."""
from vllm._aiter_ops import rocm_aiter_ops
return (
rocm_aiter_ops.is_rmsnorm_enabled()
and not rocm_aiter_ops.is_triton_gemm_enabled()
and cfg.model_config is not None
and cfg.model_config.get_hidden_size() == 2880
)

View File

@@ -20,10 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kMxfp4Static,
@@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
intermediate_size = round_up(intermediate_size, 128)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
pad_align = get_padding_alignment()
intermediate_size = round_up(intermediate_size, pad_align)
hidden_size = round_up(hidden_size, pad_align)
intermediate_size = round_up(intermediate_size, 256)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size = round_up(intermediate_size, 64)
return hidden_size, intermediate_size

View File

@@ -6,7 +6,6 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
@@ -85,14 +84,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
return quant_tensor, InFlexData(), scale
def get_padding_alignment():
return (
256
if triton.runtime.driver.active.get_current_target().arch in ("gfx950",)
else 128
)
def _dequant_mxfp4(
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor: