[Bugfix] [ROCm] [AITER]: Fix AITER block quant incompatibility with torch.compile (dynamo) (#28716)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-11-14 10:30:50 -08:00
Committed by: GitHub
Parent: 964d65deed
Commit: a425dc256e
3 changed files with 180 additions and 7 deletions
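
Why this fixes the dynamo incompatibility: torch.compile cannot trace into an opaque HIP/C++ kernel such as AITER's `get_hip_quant`, and without a registered fake (meta) implementation it has no way to derive output shapes and dtypes, so tracing the quant path graph-breaks or errors out. The commit therefore routes the call through a registered custom op with a shape-only fake. Below is a minimal, self-contained sketch of the same pattern using plain `torch.library`; every name in it (the `demo` namespace, the stand-in kernel) is illustrative, not vLLM's.

```python
import torch

# Declare the op schema first: dynamo will treat "demo::group_quant"
# as a single opaque node in the graph.
lib = torch.library.Library("demo", "FRAGMENT")  # "demo" is a made-up namespace
lib.define("group_quant(Tensor x, int group_size) -> (Tensor, Tensor)")


def _group_quant_impl(x: torch.Tensor, group_size: int):
    # Stand-in for the real AITER kernel: scale each group by its absmax.
    M, N = x.shape
    grouped = x.view(M, N // group_size, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-8)
    return (grouped / scales).view(M, N), scales.squeeze(-1)


def _group_quant_fake(x: torch.Tensor, group_size: int):
    # Shape/dtype-only version: never touches real data, so tracing is safe.
    M, N = x.shape
    return (
        torch.empty_like(x),
        torch.empty((M, N // group_size), dtype=torch.float32, device=x.device),
    )


lib.impl("group_quant", _group_quant_impl, "CompositeExplicitAutograd")
torch.library.register_fake("demo::group_quant", _group_quant_fake)

# With the fake registered, torch.compile traces through without a graph break.
fn = torch.compile(lambda x: torch.ops.demo.group_quant(x, 128), fullgraph=True)
q, s = fn(torch.randn(4, 256))
```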


@@ -43,6 +43,36 @@ def if_aiter_supported(func: Callable) -> Callable:
     return wrapper
 
 
+def _rocm_aiter_group_fp8_quant_impl(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
+    from aiter import QuantType, dtypes, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
+    return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
+
+
+def _rocm_aiter_group_fp8_quant_fake(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter import dtypes
+
+    M, N = x.shape
+    x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
+    out_bs = torch.empty(
+        (
+            M,
+            (N + group_size - 1) // group_size,
+        ),
+        dtype=torch.float32,
+        device=x.device,
+    )
+    return x_fp8, out_bs
+
+
 def _rocm_aiter_fused_moe_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
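
The fake implementation's only job is to report output metadata without running the kernel: the quantized tensor keeps the input's `(M, N)` shape in fp8, and the block scales hold one float32 per `group_size`-wide group along the last dimension, computed by ceil division. Note that the fake still does `from aiter import dtypes`, so AITER must be importable at trace time. A quick arithmetic check of the scale shape (plain Python, no GPU needed):

```python
M, N, group_size = 16, 4096, 128
scales_shape = (M, (N + group_size - 1) // group_size)  # ceil(N / group_size)
assert scales_shape == (16, 32)
```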
@@ -512,6 +542,14 @@ class rocm_aiter_ops:
         )
 
 # register all the custom ops here
+direct_register_custom_op(
+    op_name="rocm_aiter_group_fp8_quant",
+    op_func=_rocm_aiter_group_fp8_quant_impl,
+    mutates_args=[],
+    fake_impl=_rocm_aiter_group_fp8_quant_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
 direct_register_custom_op(
     op_name="rocm_aiter_asm_moe_tkw1",
     op_func=_rocm_aiter_asm_moe_tkw1_impl,
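
A suggestion beyond the PR: registrations like this can be validated with `torch.library.opcheck`, which exercises the real kernel and the fake side by side and flags shape or dtype mismatches that would otherwise show up as silent miscompiles under torch.compile. The snippet assumes a ROCm build of vLLM with AITER installed:

```python
import torch

x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
torch.library.opcheck(torch.ops.vllm.rocm_aiter_group_fp8_quant.default, (x, 128))
```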
@@ -887,14 +925,12 @@ class rocm_aiter_ops:
         return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
     @staticmethod
-    def per_1x128_fp8_quant(
+    def group_fp8_quant(
         input_2d: torch.Tensor,
+        group_size: int = 128,
     ) -> tuple[torch.Tensor, ...]:
         """Only applies quantization method for fp8 data type only."""
-        from aiter import QuantType, dtypes, get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
-        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
+        assert group_size == 128, "Group size must be 128"
+        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)
 
     @staticmethod
     def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
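
The renamed public wrapper keeps a `group_size` parameter for generality, but only 128 is wired up, since the underlying AITER kernel is the per-1x128 variant. A hypothetical call-site sketch (the import path is an assumption for illustration):

```python
import torch
from vllm._aiter_ops import rocm_aiter_ops  # module path is an assumption

x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
q, scales = rocm_aiter_ops.group_fp8_quant(x)  # default group_size=128 -> scales (4, 4)
rocm_aiter_ops.group_fp8_quant(x, group_size=64)  # AssertionError: Group size must be 128
```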


@@ -342,7 +342,7 @@ class W8A8BlockFp8LinearOp:
             )
         # MI300 uses tuned AITER ASM/C++ kernel
         else:
-            q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)
             return gemm_a8w8_blockscale_op(
                 q_input,
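
Net effect, sketched under the assumption of a ROCm vLLM build with AITER available: the MI300 block-quant path now enters the graph as a single custom-op node whose shapes come from the fake implementation, so code that hits this branch can compile with `fullgraph=True` instead of graph-breaking inside the quant call.

```python
import torch

def quant(x: torch.Tensor):
    return torch.ops.vllm.rocm_aiter_group_fp8_quant(x, 128)

compiled = torch.compile(quant, fullgraph=True)  # traced via the fake impl
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
q, scales = compiled(x)  # q: (16, 4096) fp8, scales: (16, 32) float32
```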