[Quantizaton] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
@@ -85,6 +85,32 @@ def block_dequant(
|
||||
return x_dq_block
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from triton.language import core
|
||||
|
||||
# NOTE: This can be removed when hip.libdevice.round() is available.
|
||||
@core.extern
|
||||
def round_f32(arg0, _builder=None):
|
||||
return core.extern_elementwise("",
|
||||
"", [arg0], {
|
||||
(core.dtype("fp32"), ):
|
||||
("llvm.round", core.dtype("fp32")),
|
||||
(core.dtype("fp64"), ):
|
||||
("llvm.round", core.dtype("fp64")),
|
||||
},
|
||||
is_pure=True,
|
||||
_builder=_builder)
|
||||
|
||||
@triton.jit
|
||||
def round_int8(x):
|
||||
return round_f32(x).to(tl.int8)
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def round_int8(x):
|
||||
return tl.extra.cuda.libdevice.round(x).to(tl.int8)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_quant_int8(
|
||||
x_ptr,
|
||||
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
|
||||
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
|
||||
scale_x = absmax / 127
|
||||
x_q = x * (127 / absmax)
|
||||
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
|
||||
x_q = round_int8(x_q)
|
||||
|
||||
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
|
||||
tl.store(scale_ptr + row_id, scale_x)
|
||||
|
||||
Reference in New Issue
Block a user