[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vllmellm
2025-11-10 17:20:53 +01:00
committed by GitHub
parent 40e2eeeb92
commit f080a83511
25 changed files with 1193 additions and 924 deletions

View File

@@ -12,6 +12,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -68,78 +69,6 @@ def cutlass_scaled_mm(
)
def rocm_aiter_gemm_w8a8_blockscale_impl(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = input_2d.shape[0]
n = weight.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
@@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
input_2d,
weight,
input_scale,
weight_scale,
self.act_quant_group_shape.col,
input_2d.dtype,
)
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
if (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
return rocm_aiter_ops.gemm_w8a8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
def _run_triton(
self,
@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm for MI3XX"""
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""