[ROCm]: gpt-oss fusion/padding fixes (#38043)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -20,10 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_swizzle_mxfp4,
|
||||
get_padding_alignment,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
@@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size(
|
||||
intermediate_size = round_up(intermediate_size, 128)
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif current_platform.is_rocm():
|
||||
pad_align = get_padding_alignment()
|
||||
intermediate_size = round_up(intermediate_size, pad_align)
|
||||
hidden_size = round_up(hidden_size, pad_align)
|
||||
intermediate_size = round_up(intermediate_size, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size = round_up(intermediate_size, 64)
|
||||
return hidden_size, intermediate_size
|
||||
|
||||
@@ -6,7 +6,6 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
@@ -85,14 +84,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def get_padding_alignment():
    """Return the weight-padding alignment (in elements) for the active GPU.

    The gfx950 architecture needs 256-element alignment; every other
    target uses 128.  The arch string is queried from Triton's active
    runtime driver, so this must be called where a device is available.
    """
    arch = triton.runtime.driver.active.get_current_target().arch
    if arch in ("gfx950",):
        return 256
    return 128
|
||||
|
||||
|
||||
def _dequant_mxfp4(
|
||||
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user