[Bugfix][Quantization] Ensure input contiguity in per_token_quant_int8 (#31637)

Signed-off-by: vensen <vensenmu@gmail.com>
commit 6ea001cfb7 (parent 1c46dea001)
Author: Vensen (committed via GitHub)
Date:   2026-01-11 04:40:02 +08:00
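The `_per_token_quant_int8` Triton kernel gives each token (one row of the flattened input) its own float32 scale and indexes rows of a 2-D view, so the input must be laid out contiguously; the old code asserted this, the fix copies instead. For reference, below is a minimal pure-PyTorch sketch of the same computation. The helper name `per_token_quant_int8_ref` and the symmetric absmax-to-127 scaling are illustrative assumptions, not code from this commit.

```python
# Minimal pure-PyTorch sketch of per-token INT8 quantization, for illustration only.
# The helper name and the symmetric absmax/127 scaling are assumptions, not the
# Triton kernel touched by this commit.
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # Flatten leading dims so each row is one token, mirroring the patched code path.
    original_shape = x.shape
    x2d = x.reshape(-1, original_shape[-1]).contiguous().float()
    # One symmetric scale per token: the row's max |value| mapped to the int8 limit 127.
    scales = x2d.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x2d / scales), -128, 127).to(torch.int8)
    # Restore the caller's original shape, as the patched function now does.
    return x_q.view(original_shape), scales.view(*original_shape[:-1], 1)


# The failure mode this commit targets: a non-contiguous view (transpose, slice, ...)
# used to trip `assert x.is_contiguous()`; the fix replaces the assert with
# `x = x.contiguous()` so the kernel always sees densely packed rows.
x = torch.randn(8, 16, dtype=torch.float16).t()  # .t() produces a non-contiguous view
assert not x.is_contiguous()
x_q, scales = per_token_quant_int8_ref(x)        # quantizes a contiguous copy instead
```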

@@ -122,15 +122,17 @@ def _per_token_quant_int8(
 def per_token_quant_int8(x):
+    original_shape = x.shape
+    if x.dim() > 2:
+        x = x.view(-1, original_shape[-1])
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
-    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    x_q = torch.empty((M, N), device=x.device, dtype=torch.int8)
+    scales = torch.empty((M, 1), device=x.device, dtype=torch.float32)
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
-    assert x.is_contiguous()
+    x = x.contiguous()
     _per_token_quant_int8[(M,)](
         x,
         x_q,
@@ -142,7 +144,8 @@ def per_token_quant_int8(x):
         num_warps=num_warps,
         num_stages=1,
     )
+    x_q = x_q.view(*original_shape)
+    scales = scales.view(*original_shape[:-1], 1)
     return x_q, scales
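After this change the function accepts non-contiguous and higher-rank activations and returns tensors reshaped back to the caller's original shape. A usage sketch, under the assumption that the function is importable from `vllm.model_executor.layers.quantization.utils.int8_utils` (the path may differ between vLLM versions):

```python
# Usage sketch; the import path below is an assumption and may vary across versions.
import torch
from vllm.model_executor.layers.quantization.utils.int8_utils import per_token_quant_int8

x = torch.randn(2, 4, 64, dtype=torch.float16, device="cuda")
x_t = x.transpose(0, 1)          # (4, 2, 64) view, not contiguous

x_q, scales = per_token_quant_int8(x_t)
print(x_q.shape)                 # torch.Size([4, 2, 64])  -- matches the input shape
print(scales.shape)              # torch.Size([4, 2, 1])   -- one float32 scale per token
print(x_q.dtype, scales.dtype)   # torch.int8 torch.float32
```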