fix: use randint+view for FP4/FP8 tensors in diag
This commit is contained in:
@@ -24,10 +24,10 @@ def test_nvfp4_primitives():
|
||||
|
||||
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
|
||||
M, N, K = 128, 256, 512
|
||||
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
try:
|
||||
@@ -52,8 +52,8 @@ def test_nvfp4_primitives():
|
||||
print(f" BlockScaledGEMM construction failed: {e}")
|
||||
|
||||
# Also try with E8M0 scales (MXFP4)
|
||||
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
try:
|
||||
gemm_e8m0 = BlockScaledGEMM(
|
||||
a_ptr=a_fp4,
|
||||
|
||||
Reference in New Issue
Block a user