[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-25 20:07:07 -04:00
committed by GitHub
parent 41d3082c41
commit 75d29cf4e1
6 changed files with 47 additions and 3 deletions

View File

@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size, ),
device=x.device,
dtype=torch.float32,
)
# prefer CUDA kernel if available
if current_platform.is_cuda():
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
float(int8_min),
float(int8_max))
return x_q, x_s
M = x.numel() // group_size
N = group_size
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps