[ROCM] Fix ROCm warnings, environment flag access, and GEMM kernel naming for consistency in _aiter_ops.py (#28464)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
|
||||
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
|
||||
@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
|
||||
if use_triton:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
||||
else:
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
|
||||
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
|
||||
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
@@ -105,8 +106,7 @@ def default_unquantized_gemm(
|
||||
|
||||
def use_aiter_triton_gemm(n, m, k, dtype):
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER == 0
|
||||
or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
|
||||
not rocm_aiter_ops.is_triton_gemm_enabled()
|
||||
# MI300 series: is_fp8_fnuz() is True
|
||||
or current_platform.is_fp8_fnuz()
|
||||
or dtype not in [torch.float16, torch.bfloat16]
|
||||
|
||||
Reference in New Issue
Block a user