fix: layertest now tests only the L1 GEMM, using the correct output size

The L1 GEMM produces a (tokens, 6144) gate+up output (3072 gate + 3072 up), not a (tokens, 7168) hidden-size output.
Compare the kernel's L1 output against the BF16 L1 reference only.
This commit is contained in:
2026-05-16 03:15:29 +00:00
parent 0cdcc4144a
commit 5385de3142

View File

@@ -145,7 +145,7 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
# ── CuTeDSL NVFP4 Kernel MoE Forward ──────────────────────────────────
def moe_forward_nvfp4(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
def moe_forward_nvfp4_l1_only(hidden_states, nvfp4_tensors, layer_idx, expert_ids, expert_weights):
"""Run MoE forward pass using the CuTeDSL NVFP4 kernel via bridge."""
num_tokens, hidden_size = hidden_states.shape
top_k = expert_ids.shape[1]
@@ -270,25 +270,36 @@ def main():
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
# ── BF16 reference forward pass ──
print("\n Running BF16 reference...")
ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
# ── BF16 L1 reference (gate+up only) ──
print("\n Running BF16 L1 reference...")
ref_l1 = torch.zeros(num_tokens, 6144, dtype=torch.bfloat16, device=DEVICE)
for t in range(num_tokens):
for k in range(top_k):
e = expert_ids[t, k].item()
w = expert_weights[t, k].item()
if e not in nvfp4_experts_bf16:
continue
x = hidden_states[t]
gate = x @ nvfp4_experts_bf16[e]["gate_proj"].T # (3072,)
up = x @ nvfp4_experts_bf16[e]["up_proj"].T # (3072,)
ref_l1[t] += w * torch.cat([gate, up])
print(f" BF16 L1 ref: amax={ref_l1.abs().max():.4f} mean={ref_l1.float().mean():.6f}")
del nvfp4_experts_bf16
torch.cuda.empty_cache()
# ── CuTeDSL NVFP4 kernel forward pass ──
print("\n Running CuTeDSL NVFP4 kernel (first run compiles, ~1-2 min)...")
kernel_output = moe_forward_nvfp4(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
# ── CuTeDSL NVFP4 L1 kernel ──
print("\n Running CuTeDSL NVFP4 L1 kernel (first run compiles, ~1-2 min)...")
kernel_l1 = moe_forward_nvfp4_l1_only(hidden_states, nvfp4_tensors, LAYER_IDX, expert_ids, expert_weights)
print(f" Kernel L1: amax={kernel_l1.abs().max():.4f} mean={kernel_l1.float().mean():.6f}")
# ── Compare ──
cosine = torch.nn.functional.cosine_similarity(
kernel_output.flatten().unsqueeze(0).float(),
ref_output.flatten().unsqueeze(0).float(),
kernel_l1.flatten().unsqueeze(0).float(),
ref_l1.flatten().unsqueeze(0).float(),
).item()
mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item()
mse = (kernel_l1.float() - ref_l1.float()).pow(2).mean().item()
print(f"\n{'=' * 70}")
print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")