[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
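The diff below routes all call sites through the new rocm_aiter_ops class from vllm._aiter_ops. As a reading aid, here is a minimal sketch of the part of that facade this file relies on; only the method name is_triton_gemm_w8a8_tuned and the tuned shape list are taken from the diff itself, the class body is an illustration, not the actual implementation:

    # Hypothetical sketch of the rocm_aiter_ops facade. The tuned (n, k)
    # shapes are copied verbatim from the is_aiter_triton_kernel_tuned
    # helper removed below; everything else is an assumption.
    class rocm_aiter_ops:
        _TUNED_NK = {
            (1024, 8192), (2112, 7168), (3072, 1536), (32768, 8192),
            (4096, 7168), (4608, 7168), (512, 7168), (7168, 2048),
            (7168, 256), (8192, 1024), (8192, 32768),
        }

        @classmethod
        def is_triton_gemm_w8a8_tuned(cls, n: int, k: int) -> bool:
            return (n, k) in cls._TUNED_NK

The same class also exposes per_1x128_fp8_quant, triton_gemm_a8w8_blockscale, and gemm_w8a8_blockscale, which the third hunk below calls.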
@@ -12,6 +12,7 @@ import torch
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -68,78 +69,6 @@ def cutlass_scaled_mm(
     )
 
 
-def rocm_aiter_gemm_w8a8_blockscale_impl(
-    input_2d: torch.Tensor,
-    weight: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-    group_size: int,
-    output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
-    def is_aiter_triton_kernel_tuned(n, k):
-        return (n, k) in [
-            (1024, 8192),
-            (2112, 7168),
-            (3072, 1536),
-            (32768, 8192),
-            (4096, 7168),
-            (4608, 7168),
-            (512, 7168),
-            (7168, 2048),
-            (7168, 256),
-            (8192, 1024),
-            (8192, 32768),
-        ]
-
-    n, k = weight.shape
-    if input_scale is not None:
-        q_input = input_2d
-    elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-
-        # MI350 case uses triton kernel
-        q_input, input_scale = per_token_group_quant_fp8(
-            input_2d,
-            group_size,
-            column_major_scales=False,
-            use_ue8m0=False,
-        )
-    else:
-        # MI300 uses tuned AITER ASM/C++ kernel
-        import aiter as rocm_aiter
-        from aiter import gemm_a8w8_blockscale, get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
-        q_input, input_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
-        )
-
-    return gemm_a8w8_blockscale(
-        q_input, weight, input_scale, weight_scale, dtype=output_dtype
-    )
-
-
-def rocm_aiter_gemm_w8a8_blockscale_fake(
-    input_2d: torch.Tensor,
-    weight: torch.Tensor,
-    input_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-    group_size: int,
-    output_dtype: torch.dtype = torch.float16,
-) -> torch.Tensor:
-    m = input_2d.shape[0]
-    n = weight.shape[0]
-    return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
-
-
-if current_platform.is_rocm():
-    direct_register_custom_op(
-        op_name="rocm_aiter_gemm_w8a8_blockscale",
-        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
-        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
-    )
-
-
 # TODO we should be able to change the type of block_size to GroupShape
 # after we resolve GroupShape compilation issue
 # https://github.com/vllm-project/vllm/issues/25270
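Note on the removed registration above: direct_register_custom_op exposed the kernel as torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale, and the registered fake_impl is what lets torch.compile trace the op without launching it, since it only computes the output shape and dtype. A minimal sketch of that shape logic, mirroring the removed fake implementation (the function name here is illustrative):

    import torch

    def fake_blockscale_gemm(
        input_2d: torch.Tensor,
        weight: torch.Tensor,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        # Output is (M, N): rows of the activation by rows of the weight.
        m, n = input_2d.shape[0], weight.shape[0]
        return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)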
@@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
         input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
-            input_2d,
-            weight,
-            input_scale,
-            weight_scale,
-            self.act_quant_group_shape.col,
-            input_2d.dtype,
-        )
+
+        n, k = weight.shape
+        if input_scale is not None:
+            q_input = input_2d
+
+        # MI350 case uses triton kernel
+        if (
+            not current_platform.is_fp8_fnuz()
+            and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
+        ):
+            q_input, input_scale = per_token_group_quant_fp8(
+                input_2d,
+                self.act_quant_group_shape.col,
+                column_major_scales=False,
+                use_ue8m0=False,
+            )
+            return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
+                q_input,
+                weight,
+                input_scale,
+                weight_scale,
+                input_2d.dtype,
+            )
+
+        # MI300 uses tuned AITER ASM/C++ kernel
+        else:
+            q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+            return rocm_aiter_ops.gemm_w8a8_blockscale(
+                q_input,
+                weight,
+                input_scale,
+                weight_scale,
+                input_2d.dtype,
+            )
 
     def _run_triton(
         self,
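The new inline code preserves the old dispatch. Here is a hedged, standalone sketch of the decision; the MI300/MI350 mapping is taken from the comments in the diff, and is_fp8_fnuz() returning True on MI300 (fnuz fp8) and False on MI350 (OCP fp8) is the assumption those comments imply:

    def choose_kernel(is_fp8_fnuz: bool, n: int, k: int, tuned_nk: set) -> str:
        # MI350-class GPUs take the Triton kernel only for shapes that
        # have tuned configs; everything else, including MI300, takes
        # the AITER ASM/C++ path.
        if not is_fp8_fnuz and (n, k) in tuned_nk:
            return "triton"
        return "asm"

    assert choose_kernel(False, 7168, 2048, {(7168, 2048)}) == "triton"
    assert choose_kernel(True, 7168, 2048, {(7168, 2048)}) == "asm"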
@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
     s_old.copy_(s_requant)
 
 
-def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm for MI3XX"""
-    return (
-        current_platform.is_rocm()
-        and envs.VLLM_ROCM_USE_AITER
-        and envs.VLLM_ROCM_USE_AITER_LINEAR
-    )
-
-
 def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
     """Pad the weight tensor. This is an optimization on ROCm platform, which
     can benefit from tensors located far enough from one another in memory"""
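The removed check_aiter_fp8_linear_support gated the AITER path on the platform plus two environment flags. Presumably an equivalent check now lives behind rocm_aiter_ops; the function name below is an assumption, while the three conditions are copied from the removed code:

    import vllm.envs as envs
    from vllm.platforms import current_platform

    def aiter_fp8_linear_enabled() -> bool:
        # ROCm only, and both AITER env toggles must be set.
        return (
            current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_LINEAR
        )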