[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)

This commit is contained in:
Tyler Michael Smith
2024-06-03 12:52:30 -04:00
committed by GitHub
parent 0ab278ca31
commit cbb2f59cc8
5 changed files with 16 additions and 11 deletions

View File

@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
-ops.static_scaled_int8_quant(out2, x, scale)
+scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+ops.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors