[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
This commit is contained in:
committed by GitHub
parent 0ab278ca31
commit cbb2f59cc8
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
                        torch.iinfo(torch.int8).min,
                        torch.iinfo(torch.int8).max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
-    ops.static_scaled_int8_quant(out2, x, scale)
+    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+
+    ops.static_scaled_int8_quant(out2, x, scale_argument)
     assert torch.allclose(out1, out2,
                           atol=1) # big atol to account for rounding errors
Reference in New Issue
Block a user