[ROCm]: gpt-oss fusion/padding fixes (#38043)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -152,13 +152,11 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
|
||||
|
||||
|
||||
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss.

    Args:
        cfg: The vLLM configuration object; ``cfg.model_config`` may be
            ``None`` while the config is still being assembled, so it is
            guarded below.

    Returns:
        True only when the AITER RMSNorm path is enabled, AITER Triton
        GEMMs are NOT enabled, and the model's hidden size is exactly 2880
        (the gpt-oss hidden size).
    """
    # Imported lazily to avoid a hard dependency on ROCm AITER at module
    # import time.
    from vllm._aiter_ops import rocm_aiter_ops

    return (
        rocm_aiter_ops.is_rmsnorm_enabled()
        # NOTE(review): when AITER Triton GEMMs are active this fusion is
        # skipped — presumably another path (e.g. Inductor) handles it
        # there; confirm against the upstream commit rationale.
        and not rocm_aiter_ops.is_triton_gemm_enabled()
        and cfg.model_config is not None
        and cfg.model_config.get_hidden_size() == 2880
    )
|
||||
|
||||
@@ -20,10 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_swizzle_mxfp4,
|
||||
get_padding_alignment,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
@@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
intermediate_size = round_up(intermediate_size, 128)
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif current_platform.is_rocm():
|
||||
pad_align = get_padding_alignment()
|
||||
intermediate_size = round_up(intermediate_size, pad_align)
|
||||
hidden_size = round_up(hidden_size, pad_align)
|
||||
intermediate_size = round_up(intermediate_size, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size = round_up(intermediate_size, 64)
|
||||
return hidden_size, intermediate_size
|
||||
|
||||
@@ -6,7 +6,6 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
@@ -85,14 +84,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def get_padding_alignment():
    """Return the weight-padding alignment for the active GPU target.

    gfx950 requires 256-element alignment; every other architecture
    uses 128.
    """
    arch = triton.runtime.driver.active.get_current_target().arch
    if arch in ("gfx950",):
        return 256
    return 128
|
||||
|
||||
|
||||
def _dequant_mxfp4(
|
||||
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user