diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c4ba8053c..cf0da35f8 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -137,6 +137,10 @@ def _rocm_aiter_fused_moe_fake(
     a2_scale: torch.Tensor | None = None,
     num_local_tokens: torch.Tensor | None = None,
     output_dtype: torch.dtype | None = None,
+    hidden_pad: int = 0,
+    intermediate_pad: int = 0,
+    bias1: torch.Tensor | None = None,
+    bias2: torch.Tensor | None = None,
 ) -> torch.Tensor:
     if output_dtype is not None:
         return torch.empty_like(hidden_states, dtype=output_dtype)
@@ -1700,7 +1704,7 @@ class rocm_aiter_ops:
         )

     @staticmethod
-    def triton_fp4_gemm_dynamic_qaunt(
+    def triton_fp4_gemm_dynamic_quant(
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index b2b77e668..93eb2f7f6 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -765,7 +765,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         if self.emulate:
             logger.warning_once(
                 f"The current mode (supports_mx={current_platform.supports_mx()}, "
-                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
+                f"use_rocm_aiter_moe={self.use_rocm_aiter_moe}, "
                 f"ocp_mx_scheme={self.ocp_mx_scheme}) "
                 "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
index 1b30f5b82..0b0a224f3 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@@ -3,13 +3,12 @@

 from collections.abc import Callable
 from fractions import Fraction
-from functools import cache, partial
+from functools import partial
 from typing import Any

 import torch
 import torch.nn.functional as F

-from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -37,22 +36,6 @@ from .quark_scheme import QuarkScheme

 logger = init_logger(__name__)

-# TODO: move registration of custom op to aiter_ops.py
-# `from vllm._aiter_ops import rocm_aiter_ops`
-# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
-# for envs checks which does not require @cache anymore.
-# triton kernel is torch compile compatible.
-# does not require direct registration.
-# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
-@cache
-def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
-    return (
-        current_platform.is_rocm()
-        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
-        and envs.VLLM_ROCM_USE_AITER
-    )
-
-
 try:
     from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import (
@@ -63,7 +46,7 @@ try:

     from vllm.utils.torch_utils import direct_register_custom_op

-    if is_rocm_aiter_fp4_asm_gemm_enabled():
+    if rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled():
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip

         def gemm_with_dynamic_quant(
@@ -233,7 +216,9 @@ class QuarkOCP_MX(QuarkScheme):
             self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
         )

-        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
+        self.rocm_use_aiter_fp4_asm_gemm = (
+            rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()
+        )

         if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
             # Currently need these kernels if not emulating
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index d563fbcbc..22f524cd9 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -157,13 +157,13 @@ if current_platform.is_rocm():
         total_tokens: int,
     ):
         assert kv_cache_layout in ["NHD", "SHUFFLE"], (
-            "kv_cache_layout only support NHD, SHUFFLE"
+            "kv_cache_layout only supports NHD, SHUFFLE"
         )
         head_dim = key.shape[2]
         x = 16 // key_cache.element_size()
         # assert dequant is True, "Currently, we only support "\
         # "gather cache with dequant"
-        # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
+        # For k cache layout: [num_blocks, page_size, num_heads, head_dim]
         assert head_dim == key_cache.shape[3], (
             "We assume your kv cache layout is [num_blocks, "
             "page_size, num_heads, head_dim], but got otherwise"
@@ -832,7 +832,7 @@ class AiterFlashAttentionImpl(AttentionImpl):

         if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention is not implemented for FlashAttentionImpl"
+                "Encoder self-attention is not implemented for AiterFlashAttentionImpl"
             )

     def extend_for_sliding_window(
@@ -1047,7 +1047,8 @@ class AiterFlashAttentionImpl(AttentionImpl):

         if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
-                "fused output quantization is not yet supported for FlashAttentionImpl"
+                "fused output quantization is not yet supported "
+                "for AiterFlashAttentionImpl"
             )

         if attn_metadata is None: