From e87ec6c07ae7a7d3b741b31c5daf603d7d05a236 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 08:29:16 +0000 Subject: [PATCH] fix: use randint+view for FP4/FP8 tensors in diag --- tests/unit/test_nvfp4_primitives.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index 1b27d3dc..692080d6 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -24,10 +24,10 @@ def test_nvfp4_primitives(): # Construct a BlockScaledGEMM with E4M3 scales (NVFP4) M, N, K = 128, 256, 512 - a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2) - b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2) - a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn) - b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn) + a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) + b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) + a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) + b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) from dsv4.kernels.gemm.dense import BlockScaledGEMM try: @@ -52,8 +52,8 @@ def test_nvfp4_primitives(): print(f" BlockScaledGEMM construction failed: {e}") # Also try with E8M0 scales (MXFP4) - a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) - b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) + a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) + b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) try: gemm_e8m0 = BlockScaledGEMM( a_ptr=a_fp4,