[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)
Signed-off-by: Li <chuali@amd.com>
This commit is contained in:
committed by
GitHub
parent
a16133a0f1
commit
e99fb98867
@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
|
||||
f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
|
||||
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from fractions import Fraction
|
||||
from functools import cache, partial
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: move registration of custom op to aiter_ops.py
|
||||
# `from vllm._aiter_ops import rocm_aiter_ops`
|
||||
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
|
||||
# for envs checks which does not require @cache anymore.
|
||||
# triton kernel is torch compile compatible.
|
||||
# does not require direct registration.
|
||||
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
|
||||
@cache
|
||||
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import (
|
||||
@@ -63,7 +46,7 @@ try:
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
if is_rocm_aiter_fp4_asm_gemm_enabled():
|
||||
if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
|
||||
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
|
||||
|
||||
def gemm_with_dynamic_quant(
|
||||
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
|
||||
)
|
||||
|
||||
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
|
||||
self.rocm_use_aiter_fp4_asm_gemm = (
|
||||
rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
|
||||
)
|
||||
|
||||
if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
|
||||
# Currently need these kernels if not emulating
|
||||
|
||||
Reference in New Issue
Block a user