[ Kernel ] Enable Dynamic Per Token fp8 (#6547)

This commit is contained in:
Robert Shaw
2024-07-19 19:08:15 -04:00
committed by GitHub
parent 07eb6f19f3
commit 4cc24f01b1
7 changed files with 67 additions and 38 deletions

View File

@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
device="cuda") + 1e-6 # avoid nans
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),