[ROCM] Fix ROCm warnings, environment flag access, and GEMM kernel naming for consistency in _aiter_ops.py (#28464)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2025-11-13 05:46:57 +08:00
committed by GitHub
parent 74a9a9faad
commit d8140b9833
5 changed files with 33 additions and 29 deletions

View File

@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
if use_triton:
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
else:
-            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
if input_scale is not None:
q_input = input_2d

View File

@@ -8,6 +8,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@@ -105,8 +106,7 @@ def default_unquantized_gemm(
def use_aiter_triton_gemm(n, m, k, dtype):
if (
-        envs.VLLM_ROCM_USE_AITER == 0
-        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
+        not rocm_aiter_ops.is_triton_gemm_enabled()
# MI300's - fp8nuz=True
or current_platform.is_fp8_fnuz()
or dtype not in [torch.float16, torch.bfloat16]