[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
|
||||
int8_min = iinfo.min
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size, ),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# prefer CUDA kernel if available
|
||||
if current_platform.is_cuda():
|
||||
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
|
||||
float(int8_min),
|
||||
float(int8_max))
|
||||
return x_q, x_s
|
||||
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
|
||||
Reference in New Issue
Block a user