more debug2
This commit is contained in:
@@ -28,7 +28,10 @@ def cutlass_nvfp4_blockscaled_gemm(
|
||||
"""Single NVFP4 block-scaled GEMM using CUTLASS."""
|
||||
if not _CUTLASS_AVAILABLE:
|
||||
raise RuntimeError("CUTLASS NVFP4 GEMM extension not available")
|
||||
return _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
|
||||
result = _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
|
||||
if result < 0:
|
||||
raise RuntimeError(f"CUTLASS NVFP4 GEMM failed with code {result} — can_implement may have rejected the problem (M={M} N={N} K={K})")
|
||||
return result
|
||||
|
||||
|
||||
def cutlass_grouped_nvfp4_gemm(
|
||||
@@ -75,6 +78,15 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
|
||||
M_expert = token_indices.shape[0]
|
||||
|
||||
# DEBUG: verify data going into GEMM
|
||||
if e == 0:
|
||||
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
|
||||
f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} "
|
||||
f"w absmax={expert_w.view(torch.int8).abs().max().item()} "
|
||||
f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, "
|
||||
f"{expert_w_sf.to(torch.float32).max().item():.4e}] "
|
||||
f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}")
|
||||
|
||||
# Run CUTLASS NVFP4 block-scaled GEMM
|
||||
expert_out = cutlass_nvfp4_blockscaled_gemm(
|
||||
expert_x, expert_x_sf,
|
||||
|
||||
Reference in New Issue
Block a user