[ Kernel ] Enable Dynamic Per Token fp8 (#6547)
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                    device="cuda") + 1e-6  # avoid nans
 
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
 
     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
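The diff routes per-token dynamic fp8 quantization through the general scaled_fp8_quant entry point via the use_per_token_if_dynamic flag, rather than calling a dedicated dynamic_per_token_scaled_fp8_quant op. A minimal usage sketch, assuming a CUDA build of vLLM that exposes this op; the tensor sizes and the per-token scale shape shown here are illustrative:

    # Minimal sketch of the new call path (assumes a CUDA build of vLLM).
    import torch
    from vllm import _custom_ops as ops

    # 16 tokens with a hidden size of 1024; the small offset avoids nans,
    # mirroring the test above.
    x = torch.rand(16, 1024, dtype=torch.float16, device="cuda") + 1e-6

    # No static scale is passed, so quantization is dynamic; the flag selects
    # one fp8 scale per token instead of a single per-tensor scale.
    out, scales = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)

    assert out.dtype == torch.float8_e4m3fn
    assert scales.numel() == x.shape[0]  # one scale per token (illustrative)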