fix: use quantize_activation_nvfp4 in diag
This commit is contained in:
@@ -93,10 +93,11 @@ def test_nvfp4_primitives():
|
||||
print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}")
|
||||
|
||||
# The runner uses E4M3 or E8M0? Check what quantize actually produces
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
|
||||
x_fp4, x_sf = quantize_to_nvfp4(x)
|
||||
print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}")
|
||||
global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32)
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x, global_scale)
|
||||
print(f" quantize_activation_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}")
|
||||
if x_sf.dtype == torch.float8_e4m3fn:
|
||||
print(f" ✅ Scale factors are E4M3 — NVFP4 correct")
|
||||
elif x_sf.dtype == torch.float8_e8m0fnu:
|
||||
|
||||
Reference in New Issue
Block a user