From 5385de314237a0c7f5eb5ea05beac6e7a88d055f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 03:15:29 +0000 Subject: [PATCH] fix: layertest tests L1 GEMM only with correct output size L1 produces (tokens, 6144) gate+up, not (tokens, 7168) hidden. Compare against BF16 L1 reference only. --- tests/layertest.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/layertest.py b/tests/layertest.py index 1d30b4bc..76581a35 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -145,7 +145,7 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights): # ── CuTeDSL NVFP4 Kernel MoE Forward ────────────────────────────────── -def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights): +def moe_forward_nvfp4_l1_only(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights): """Run MoE forward pass using the CuTeDSL NVFP4 kernel via bridge.""" num_tokens, hidden_size = hidden_states.shape top_k = expert_ids.shape[1] @@ -270,25 +270,36 @@ def main(): expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE) expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE) - # ── BF16 reference forward pass ── - print("\n Running BF16 reference...") - ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights) - print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}") + # ── BF16 L1 reference (gate+up only) ── + print("\n Running BF16 L1 reference...") + ref_l1 = torch.zeros(num_tokens, 6144, dtype=torch.bfloat16, device=DEVICE) + for t in range(num_tokens): + for k in range(top_k): + e = expert_ids[t, k].item() + w = expert_weights[t, k].item() + if e not in nvfp4_experts_bf16: + continue + x = hidden_states[t] + gate = x @ nvfp4_experts_bf16[e]["gate_proj"].T # (3072,) + up = x @ nvfp4_experts_bf16[e]["up_proj"].T # (3072,) + ref_l1[t] += w * torch.cat([gate, up]) + + print(f" BF16 L1 ref: amax={ref_l1.abs().max():.4f} mean={ref_l1.float().mean():.6f}") del nvfp4_experts_bf16 torch.cuda.empty_cache() - # ── CuTeDSL NVFP4 kernel forward pass ── - print("\n Running CuTeDSL NVFP4 kernel (first run compiles, ~1-2 min)...") - kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights) - print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}") + # ── CuTeDSL NVFP4 L1 kernel ── + print("\n Running CuTeDSL NVFP4 L1 kernel (first run compiles, ~1-2 min)...") + kernel_l1 = moe_forward_nvfp4_l1_only(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights) + print(f" Kernel L1: amax={kernel_l1.abs().max():.4f} mean={kernel_l1.float().mean():.6f}") # ── Compare ── cosine = torch.nn.functional.cosine_similarity( - kernel_output.flatten().unsqueeze(0).float(), - ref_output.flatten().unsqueeze(0).float(), + kernel_l1.flatten().unsqueeze(0).float(), + ref_l1.flatten().unsqueeze(0).float(), ).item() - mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item() + mse = (kernel_l1.float() - ref_l1.float()).pow(2).mean().item() print(f"\n{'=' * 70}") print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")