[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)

Signed-off-by: Li <chuali@amd.com>
Chuan (Richard) Li
2026-03-23 00:48:31 -07:00
committed by GitHub
parent a16133a0f1
commit e99fb98867
4 changed files with 16 additions and 26 deletions
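Background on the "signature mismatch" in the title: vLLM registers the AITER fused-MoE kernel as a torch custom op together with a fake (meta) implementation, and the fake must accept every parameter the real op accepts, otherwise a traced call that passes the newer arguments fails. The sketch below is illustrative only (demo::fused_moe and its parameters are made up for the example, and it uses plain torch.library rather than vLLM's direct_register_custom_op helper):

import torch
from typing import Optional
from torch.library import custom_op, register_fake


# Real op: this parameter list is the contract that callers and torch.compile see.
@custom_op("demo::fused_moe", mutates_args=())
def fused_moe(
    hidden_states: torch.Tensor,
    bias1: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    out = hidden_states.clone()
    if bias1 is not None:
        out = out + bias1
    return out


# Fake (meta) implementation: it only describes the output shape/dtype, but it
# still has to mirror the real signature. If the real kernel gains a new
# argument (here, bias1), the fake must gain it too, or calls that pass the
# argument fail with a signature mismatch.
@register_fake("demo::fused_moe")
def fused_moe_fake(
    hidden_states: torch.Tensor,
    bias1: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)

The first hunk below applies the same rule to _rocm_aiter_fused_moe_fake, adding the hidden_pad, intermediate_pad, bias1, and bias2 parameters that the real AITER MoE op already takes.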

View File

@@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
     a2_scale: torch.Tensor | None = None,
     num_local_tokens: torch.Tensor | None = None,
     output_dtype: torch.dtype | None = None,
+    hidden_pad: int = 0,
+    intermediate_pad: int = 0,
+    bias1: torch.Tensor | None = None,
+    bias2: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if output_dtype is not None:
         return torch.empty_like(hidden_states, dtype=output_dtype)
@@ -1700,7 +1704,7 @@ class rocm_aiter_ops:
     )

     @staticmethod
-    def triton_fp4_gemm_dynamic_qaunt(
+    def triton_fp4_gemm_dynamic_quant(
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,

View File

@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
-                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
+                f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
                 f"ocp_mx_scheme={self.ocp_mx_scheme}) "
                 "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "

View File

@@ -3,13 +3,12 @@
 from collections.abc import Callable
 from fractions import Fraction
-from functools import cache, partial
+from functools import partial
 from typing import Any

 import torch
 import torch.nn.functional as F

-from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme

 logger = init_logger(__name__)

-# TODO: move registration of custom op to aiter_ops.py
-# `from vllm._aiter_ops import rocm_aiter_ops`
-# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
-# for envs checks which does not require @cache anymore.
-# triton kernel is torch compile compatible.
-# does not require direct registration.
-# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
-@cache
-def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
-    return (
-        current_platform.is_rocm()
-        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
-        and envs.VLLM_ROCM_USE_AITER
-    )
-
 try:
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import (
@@ -63,7 +46,7 @@ try:
     from vllm.utils.torch_utils import direct_register_custom_op

-    if is_rocm_aiter_fp4_asm_gemm_enabled():
+    if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip

     def gemm_with_dynamic_quant(
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
             self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
         )
-        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+        self.rocm_use_aiter_fp4_asm_gemm = (
+            rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
+        )

         if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
             # Currently need these kernels if not emulating

View File

@@ -157,13 +157,13 @@ if current_platform.is_rocm():
         total_tokens: int,
     ):
         assert kv_cache_layout in ["NHD", "SHUFFLE"], (
-            "kv_cache_layout only support NHD, SHUFFLE"
+            "kv_cache_layout only supports NHD, SHUFFLE"
         )
         head_dim = key.shape[2]
         x = 16 // key_cache.element_size()
         # assert dequant is True, "Currently, we only support "\
         #     "gather cache with dequant"
-        # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+        # For k cache layout: [num_blocks, page_size, num_heads, head_dim]
         assert head_dim == key_cache.shape[3], (
             "We assume your kv cache layout is [num_blocks, "
             "page_size, num_heads, head_dim], but got otherwise"
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):

         if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention is not implemented for FlashAttentionImpl"
+                "Encoder self-attention is not implemented for AiterFlashAttentionImpl"
             )

     def extend_for_sliding_window(
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
         if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
-                "fused output quantization is not yet supported for FlashAttentionImpl"
+                "fused output quantization is not yet supported "
+                "for AiterFlashAttentionImpl"
             )

         if attn_metadata is None: