[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)

Signed-off-by: Li <chuali@amd.com>
Chuan (Richard) Li
2026-03-23 00:48:31 -07:00
committed by GitHub
parent a16133a0f1
commit e99fb98867
4 changed files with 16 additions and 26 deletions
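The headline fix aligns the fake (meta) implementation of the AITER fused-MoE custom op with the real kernel: the first hunk below adds `hidden_pad`, `intermediate_pad`, `bias1`, and `bias2` to `_rocm_aiter_fused_moe_fake`. A fake implementation is invoked with the same arguments as the real op, so its signature must accept everything the kernel accepts or compiled call sites fail with a signature mismatch. A minimal sketch of that contract using plain `torch.library` rather than vLLM's `direct_register_custom_op` helper (the op name and parameters here are illustrative, not the actual vLLM registration):

```python
# Minimal sketch (illustrative names, not vLLM's actual registration):
# the fake implementation must mirror the real op's signature, including
# any newly added keyword arguments such as hidden_pad.
import torch
from torch.library import custom_op


@custom_op("demo::fused_moe", mutates_args=())
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    hidden_pad: int = 0,
) -> torch.Tensor:
    # Stand-in for the real kernel; the sketch only cares about metadata.
    return hidden_states.clone()


@fused_moe.register_fake
def _(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    hidden_pad: int = 0,
) -> torch.Tensor:
    # Describes output shape/dtype only. If hidden_pad were missing here,
    # traced call sites that pass it would fail with a signature mismatch.
    return torch.empty_like(hidden_states)
```

Under torch.compile the fake path is what gets traced, which is where a mismatched fake signature surfaces as an error.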


@@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
+hidden_pad: int = 0,
+intermediate_pad: int = 0,
+bias1: torch.Tensor | None = None,
+bias2: torch.Tensor | None = None,
) -> torch.Tensor:
if output_dtype is not None:
return torch.empty_like(hidden_states, dtype=output_dtype)
@@ -1700,7 +1704,7 @@ class rocm_aiter_ops:
)
@staticmethod
-def triton_fp4_gemm_dynamic_qaunt(
+def triton_fp4_gemm_dynamic_quant(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,


@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "


@@ -3,13 +3,12 @@
from collections.abc import Callable
from fractions import Fraction
-from functools import cache, partial
+from functools import partial
from typing import Any
import torch
import torch.nn.functional as F
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
-# TODO: move registration of custom op to aiter_ops.py
-# `from vllm._aiter_ops import rocm_aiter_ops`
-# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
-# for envs checks which does not require @cache anymore.
-# triton kernel is torch compile compatible.
-# does not require direct registration.
-# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
-@cache
-def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
-return (
-current_platform.is_rocm()
-and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
-and envs.VLLM_ROCM_USE_AITER
-)
try:
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import (
@@ -63,7 +46,7 @@ try:
from vllm.utils.torch_utils import direct_register_custom_op
-if is_rocm_aiter_fp4_asm_gemm_enabled():
+if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
def gemm_with_dynamic_quant(
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
)
-self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+self.rocm_use_aiter_fp4_asm_gemm = (
+rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
+)
if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
# Currently need these kernels if not emulating
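This file's other change drops the module-level, `@cache`-decorated `is_rocm_aiter_fp4_asm_gemm_enabled()` helper in favor of `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`, so the environment-flag check lives in one place. A rough sketch of the before/after pattern (simplified: the real helpers also check `current_platform.is_rocm()` and read flags through `vllm.envs`, and the class name here is illustrative):

```python
# Sketch of centralizing an env-flag feature gate (not vLLM's actual code).
import os
from functools import cache


# Before (removed): a per-module, cached free function.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
    return (
        os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
        and os.environ.get("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "0") == "1"
    )


# After (sketch): every call site asks one shared ops helper instead of
# re-declaring the env logic next to each kernel wrapper.
class RocmAiterOps:
    @staticmethod
    def is_asm_fp4_gemm_dynamic_quant_enabled() -> bool:
        return (
            os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_FP4_ASM_GEM" "M", "0") == "1"
        )


use_asm_fp4 = RocmAiterOps.is_asm_fp4_gemm_dynamic_quant_enabled()
```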


@@ -157,13 +157,13 @@ if current_platform.is_rocm():
total_tokens: int,
):
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
"kv_cache_layout only support NHD, SHUFFLE"
"kv_cache_layout only supports NHD, SHUFFLE"
)
head_dim = key.shape[2]
x = 16 // key_cache.element_size()
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
-# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+# For k cache layout: [num_blocks, page_size, num_heads, head_dim]
assert head_dim == key_cache.shape[3], (
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for FlashAttentionImpl"
"Encoder self-attention is not implemented for AiterFlashAttentionImpl"
)
def extend_for_sliding_window(
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for FlashAttentionImpl"
"fused output quantization is not yet supported "
"for AiterFlashAttentionImpl"
)
if attn_metadata is None: