fix: build FP4/FP8 tensors in diag from uint8 randint + .view(dtype) instead of randn

This commit is contained in:
2026-05-23 08:29:16 +00:00
parent 44a917b277
commit e87ec6c07a

View File

@@ -24,10 +24,10 @@ def test_nvfp4_primitives():
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
M, N, K = 128, 256, 512
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
from dsv4.kernels.gemm.dense import BlockScaledGEMM
try:
@@ -52,8 +52,8 @@ def test_nvfp4_primitives():
print(f" BlockScaledGEMM construction failed: {e}")
# Also try with E8M0 scales (MXFP4)
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
try:
gemm_e8m0 = BlockScaledGEMM(
a_ptr=a_fp4,