[Quantizaton] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith
2025-05-02 23:41:10 -05:00
committed by GitHub
parent d47b605eca
commit e3d0a1d190
2 changed files with 29 additions and 3 deletions

View File

@@ -85,6 +85,32 @@ def block_dequant(
return x_dq_block
if current_platform.is_rocm():
from triton.language import core
# NOTE: This can be removed when hip.libdevice.round() is available.
@core.extern
def round_f32(arg0, _builder=None):
return core.extern_elementwise("",
"", [arg0], {
(core.dtype("fp32"), ):
("llvm.round", core.dtype("fp32")),
(core.dtype("fp64"), ):
("llvm.round", core.dtype("fp64")),
},
is_pure=True,
_builder=_builder)
@triton.jit
def round_int8(x):
return round_f32(x).to(tl.int8)
else:
@triton.jit
def round_int8(x):
return tl.extra.cuda.libdevice.round(x).to(tl.int8)
@triton.jit
def _per_token_quant_int8(
x_ptr,
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
x_q = round_int8(x_q)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)