more debug2

This commit is contained in:
2026-05-15 05:08:53 +00:00
parent 912e4622d7
commit 76e9b078a2

View File

@@ -28,7 +28,10 @@ def cutlass_nvfp4_blockscaled_gemm(
"""Single NVFP4 block-scaled GEMM using CUTLASS."""
if not _CUTLASS_AVAILABLE:
raise RuntimeError("CUTLASS NVFP4 GEMM extension not available")
return _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
result = _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
if result < 0:
raise RuntimeError(f"CUTLASS NVFP4 GEMM failed with code {result} — can_implement may have rejected the problem (M={M} N={N} K={K})")
return result
def cutlass_grouped_nvfp4_gemm(
@@ -75,6 +78,15 @@ def cutlass_grouped_nvfp4_gemm(
M_expert = token_indices.shape[0]
# DEBUG: verify data going into GEMM
if e == 0:
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} "
f"w absmax={expert_w.view(torch.int8).abs().max().item()} "
f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, "
f"{expert_w_sf.to(torch.float32).max().item():.4e}] "
f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}")
# Run CUTLASS NVFP4 block-scaled GEMM
expert_out = cutlass_nvfp4_blockscaled_gemm(
expert_x, expert_x_sf,