[Bugfix][Quantization] Ensure input contiguity in per_token_quant_int8 (#31637)
Signed-off-by: vensen <vensenmu@gmail.com>
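
Prior to this change, the wrapper asserted x.is_contiguous() and crashed on any strided view (e.g. the output of a transpose or slice) before the kernel ever launched; the fix copies such inputs to a dense layout instead, flattens inputs with more than two dimensions, and reshapes the outputs back. A minimal sketch of the guard-level difference, using hypothetical old_guard/new_guard stand-ins rather than the real wrapper (the Triton kernel launch is omitted):

import torch

# Stand-ins for the guard logic only; the real per_token_quant_int8 also
# launches the Triton kernel, which is omitted here.
def old_guard(x: torch.Tensor) -> torch.Tensor:
    assert x.is_contiguous()  # pre-fix: hard failure on strided input
    return x

def new_guard(x: torch.Tensor) -> torch.Tensor:
    return x.contiguous()  # post-fix: copy to a dense layout when needed

x = torch.randn(8, 16).t()  # transposed view, shape (16, 8), non-contiguous

try:
    old_guard(x)
except AssertionError:
    print("pre-fix: AssertionError on non-contiguous input")

y = new_guard(x)
print(y.is_contiguous())  # True: same values, now in a dense, kernel-safe layout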
@@ -122,15 +122,17 @@ def _per_token_quant_int8(
 
 def per_token_quant_int8(x):
+    original_shape = x.shape
+    if x.dim() > 2:
+        x = x.view(-1, original_shape[-1])
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
-    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    x_q = torch.empty((M, N), device=x.device, dtype=torch.int8)
+    scales = torch.empty((M, 1), device=x.device, dtype=torch.float32)
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
 
-    assert x.is_contiguous()
+    x = x.contiguous()
     _per_token_quant_int8[(M,)](
         x,
         x_q,
@@ -142,7 +144,8 @@ def per_token_quant_int8(x):
         num_warps=num_warps,
         num_stages=1,
     )
 
+    x_q = x_q.view(*original_shape)
+    scales = scales.view(*original_shape[:-1], 1)
     return x_q, scales
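
For reference, assuming the kernel follows the usual symmetric absmax convention (scale = row absmax / 127), here is an eager-mode sketch of the fixed wrapper's semantics; the reference name, the eps guard, and the use of reshape() in place of the view()/contiguous() pair are illustrative, not taken from this commit:

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # Hypothetical eager-mode reference: symmetric per-token (per-row)
    # int8 quantization with scale = row absmax / 127.
    original_shape = x.shape
    # reshape() copies when the input is a strided view, playing the same
    # role as the view() + contiguous() pair in the fixed wrapper.
    x = x.reshape(-1, original_shape[-1]).contiguous()
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-8)
    scales = absmax.to(torch.float32) / 127.0
    x_q = torch.round(x / scales).clamp_(-128, 127).to(torch.int8)
    return x_q.view(*original_shape), scales.view(*original_shape[:-1], 1)

# A non-contiguous 2-D input and a 3-D input both round-trip cleanly.
for x in (torch.randn(16, 8).t(), torch.randn(2, 4, 8)):
    x_q, scales = per_token_quant_int8_ref(x)
    print(x.is_contiguous(), x_q.shape, scales.shape)

The transposed input above is exactly the kind of input the pre-fix assert rejected; the 3-D input exercises the flatten and reshape-back path added in this commit.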