[Perf] CUDA Kernel for Per Token Group Quant (#21083)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye
2025-07-22 10:27:15 -04:00
committed by GitHub
parent 2c8db17cfd
commit 774d0c014b
6 changed files with 285 additions and 4 deletions


@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
+    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
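
With this change, callers can pin the scale format explicitly instead of always inheriting the Blackwell/DeepGEMM autodetection. A minimal usage sketch (the import path and tensor shapes here are my assumptions, not part of this diff):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
# Quantize in groups of 128 along the last dim; force plain FP8 scales.
x_q, x_s = per_token_group_quant_fp8(x, group_size=128, use_ue8m0=False)
assert x_s.shape == (16, 512 // 128)  # one float32 scale per token group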
@@ -397,8 +398,7 @@
     if x_q is None:
         x_q = torch.empty_like(x, device=x.device, dtype=dtype)
 
-    M = x.numel() // group_size
-    N = group_size
+    # Allocate the scale tensor in either row- or column-major format.
     if column_major_scales:
         shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
         x_s = torch.empty(shape, device=x.device,
@@ -407,6 +407,15 @@
         shape = x.shape[:-1] + (x.shape[-1] // group_size, )
         x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
 
+    # prefer CUDA kernel if available
+    if current_platform.is_cuda() and x.is_contiguous():
+        torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
+                                               fp8_min, fp8_max, use_ue8m0)
+        return x_q, x_s
+
+    # TRITON FALLBACK
+    M = x.numel() // group_size
+    N = group_size
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
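
The new CUDA path is taken only for contiguous inputs (presumably so the custom op can walk the tensor as flat groups); anything else falls through to the Triton kernel, which sizes its launch from the group width. A standalone sketch of that launch heuristic (pure Python, no Triton dependency; the helper name is mine):

# Sketch of the fallback launch heuristic above; helper name is mine.
def triton_launch_params(group_size: int) -> tuple[int, int]:
    # BLOCK: next power of two >= group_size (what triton.next_power_of_2 gives).
    block = 1 << max(group_size - 1, 0).bit_length()
    # Roughly one warp per 256 elements of BLOCK, clamped to [1, 8].
    num_warps = min(max(block // 256, 1), 8)
    return block, num_warps

assert triton_launch_params(128) == (128, 1)
assert triton_launch_params(4096) == (4096, 8)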
@@ -423,7 +432,7 @@
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -439,7 +448,7 @@
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
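
Both Triton launch sites now thread the caller's use_ue8m0 through instead of re-querying is_blackwell_deep_gemm_used() at each call. For reference, here is a hedged PyTorch sketch of the per-group math as I read the Triton kernel; it is not the CUDA op, and the UE8M0 round-up-to-power-of-two behavior is my assumption:

import torch

def per_group_quant_ref(x: torch.Tensor, group_size: int,
                        eps: float = 1e-10, use_ue8m0: bool = False):
    # Reference sketch: scale each group by max(|x|, eps) / fp8_max,
    # then cast the scaled values to float8_e4m3fn.
    finfo = torch.finfo(torch.float8_e4m3fn)
    g = x.float().reshape(-1, group_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / finfo.max
    if use_ue8m0:
        # UE8M0: exponent-only scale, rounded up to a power of two (assumption).
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = (g / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)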