[CI/Build][Kernel][BugFix][AMD] Fix per_token_group_quant_fp8 to use correct fp8 min/max values and update atol/rtol in test_quantfp8_group_functionality (#30292)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -62,7 +62,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
@@ -71,7 +71,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_cuda.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# Verify CUDA/native consistency
|
||||
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
|
||||
|
||||
# Quantized values should mostly match
|
||||
diff_count = (x_quant_cuda != x_quant_native).sum().item()
|
||||
|
||||
Reference in New Issue
Block a user