[ROCm] Fix fused_moe_fake signature mismatch and other AITER bugs (#36100)
Signed-off-by: Li <chuali@amd.com>
Committed by: GitHub
Parent: a16133a0f1
Commit: e99fb98867
@@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
    a2_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
    hidden_pad: int = 0,
    intermediate_pad: int = 0,
    bias1: torch.Tensor | None = None,
    bias2: torch.Tensor | None = None,
) -> torch.Tensor:
    if output_dtype is not None:
        return torch.empty_like(hidden_states, dtype=output_dtype)
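Aside: why the fake signature matters. The fused-MoE op is registered as a torch custom op, and the fake (meta) implementation is what torch.compile and meta-device tracing call instead of the real kernel; if it accepts fewer arguments than the real op, any caller passing the newer keywords fails at trace time. Below is a minimal stand-alone sketch of that contract using plain torch.library with illustrative names (vLLM's own registration goes through direct_register_custom_op, not this decorator):

from typing import Optional

import torch


@torch.library.custom_op("demo::fused_moe", mutates_args=())
def demo_fused_moe(
    hidden_states: torch.Tensor,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # stand-in for the real AITER kernel
    out_dtype = output_dtype if output_dtype is not None else hidden_states.dtype
    return hidden_states.clone().to(out_dtype)


@demo_fused_moe.register_fake
def _(
    hidden_states: torch.Tensor,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # must mirror the real signature; it only describes the output shape/dtype
    if output_dtype is not None:
        return torch.empty_like(hidden_states, dtype=output_dtype)
    return torch.empty_like(hidden_states)


x = torch.randn(4, 8)
print(demo_fused_moe(x, output_dtype=torch.bfloat16).dtype)  # torch.bfloat16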
@@ -1700,7 +1704,7 @@ class rocm_aiter_ops:
        )

    @staticmethod
-   def triton_fp4_gemm_dynamic_qaunt(
+   def triton_fp4_gemm_dynamic_quant(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
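For readers unfamiliar with the renamed helper: "dynamic quant" means the activation is quantized on the fly before the low-precision GEMM against pre-quantized weights. A rough pure-PyTorch stand-in for that flow, with per-row absmax scaling standing in for the MXFP4 group scales (illustrative only, not the Triton kernel and not vLLM's API):

import torch


def gemm_dynamic_quant_reference(
    x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor
) -> torch.Tensor:
    # quantize the activation dynamically (per-row absmax, int8-style grid)
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp((x / x_scale).round(), -127, 127)
    # dequantize-and-matmul reference; the real kernel fuses these steps
    return (x_q * x_scale) @ (w_q * w_scale).T


x = torch.randn(4, 64)
w_q = torch.randn(8, 64)      # stands in for the packed FP4 weight
w_scale = torch.ones(8, 1)    # stands in for the per-group weight scales
print(gemm_dynamic_quant_reference(x, w_q, w_scale).shape)  # torch.Size([4, 8])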
@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
        if self.emulate:
            logger.warning_once(
                f"The current mode (supports_mx={current_platform.supports_mx()}, "
-               f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
+               f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
                "does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
@@ -3,13 +3,12 @@

from collections.abc import Callable
from fractions import Fraction
-from functools import cache, partial
+from functools import partial
from typing import Any

import torch
import torch.nn.functional as F

from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__)


-# TODO: move registration of custom op to aiter_ops.py
-# `from vllm._aiter_ops import rocm_aiter_ops`
-# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
-# for envs checks which does not require @cache anymore.
-# triton kernel is torch compile compatible.
-# does not require direct registration.
-# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
-@cache
-def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
-    return (
-        current_platform.is_rocm()
-        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
-        and envs.VLLM_ROCM_USE_AITER
-    )
-
-
try:
    from aiter.ops.shuffle import shuffle_weight
    from aiter.ops.triton.gemm_afp4wfp4 import (
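The removed TODO notes that the rocm_aiter_ops checks "do not require @cache anymore". The practical problem with caching a module-local env check is that the first call freezes the answer, which is easy to trip over when the flag is toggled later (for example in tests). A small generic illustration of that behavior (not vLLM code):

import os
from functools import cache


@cache
def flag_enabled() -> bool:
    return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"


os.environ["VLLM_ROCM_USE_AITER"] = "1"
print(flag_enabled())  # True

os.environ["VLLM_ROCM_USE_AITER"] = "0"
print(flag_enabled())  # still True: the cached result never re-reads the env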
@@ -63,7 +46,7 @@ try:

    from vllm.utils.torch_utils import direct_register_custom_op

-   if is_rocm_aiter_fp4_asm_gemm_enabled():
+   if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
        from aiter import gemm_a4w4, per_1x32_f4_quant_hip

    def gemm_with_dynamic_quant(
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        )

-       self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+       self.rocm_use_aiter_fp4_asm_gemm = (
+           rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
+       )

        if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
            # Currently need these kernels if not emulating
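The same switch is now read through the shared rocm_aiter_ops facade instead of a module-local helper. A minimal sketch of that pattern with hypothetical names (the real class is imported from vllm/_aiter_ops.py and consults the platform plus the VLLM_ROCM_USE_AITER* flags):

class _AiterOpsFacadeExample:
    """Hypothetical stand-in for a central capability/ops facade."""

    @staticmethod
    def is_asm_fp4_gemm_dynamic_quant_enabled() -> bool:
        # placeholder: the real check consults platform + env flags
        return False


class _SchemeExample:
    def __init__(self) -> None:
        # every scheme queries the same source of truth
        self.rocm_use_aiter_fp4_asm_gemm = (
            _AiterOpsFacadeExample.is_asm_fp4_gemm_dynamic_quant_enabled()
        )


print(_SchemeExample().rocm_use_aiter_fp4_asm_gemm)  # False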
@@ -157,13 +157,13 @@ if current_platform.is_rocm():
        total_tokens: int,
    ):
        assert kv_cache_layout in ["NHD", "SHUFFLE"], (
-           "kv_cache_layout only support NHD, SHUFFLE"
+           "kv_cache_layout only supports NHD, SHUFFLE"
        )
        head_dim = key.shape[2]
        x = 16 // key_cache.element_size()
        # assert dequant is True, "Currently, we only support "\
        #     "gather cache with dequant"
-       # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+       # For k cache layout: [num_blocks, page_size, num_heads, head_dim]
        assert head_dim == key_cache.shape[3], (
            "We assume your kv cache layout is [num_blocks, "
            "page_size, num_heads, head_dim], but got otherwise"
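The corrected comment and the assert both describe the same paged layout, [num_blocks, page_size, num_heads, head_dim]. A quick shape check with dummy sizes (illustrative values only):

import torch

num_blocks, page_size, num_heads, head_dim = 4, 16, 8, 128
key_cache = torch.empty(num_blocks, page_size, num_heads, head_dim, dtype=torch.float16)
key = torch.empty(32, num_heads, head_dim, dtype=torch.float16)  # [num_tokens, num_heads, head_dim]

assert key.shape[2] == key_cache.shape[3]   # head_dim, as asserted in the kernel above
x = 16 // key_cache.element_size()          # elements per 16-byte chunk; 8 for fp16
print(key_cache.shape, x)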
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):

        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
            raise NotImplementedError(
-               "Encoder self-attention is not implemented for FlashAttentionImpl"
+               "Encoder self-attention is not implemented for AiterFlashAttentionImpl"
            )

    def extend_for_sliding_window(
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
-               "fused output quantization is not yet supported for FlashAttentionImpl"
+               "fused output quantization is not yet supported "
+               "for AiterFlashAttentionImpl"
            )

        if attn_metadata is None: