diff --git a/tests/test_multilayer.py b/tests/test_multilayer.py index 6a34ffd1..9a952877 100644 --- a/tests/test_multilayer.py +++ b/tests/test_multilayer.py @@ -111,11 +111,13 @@ def main(): for layer in range(NUM_LAYERS): with torch.no_grad(): # Runner + run_hidden_saved = run_hidden.clone() runner.compute_activation_global_scales(run_hidden, topk_weights, topk_ids) run_out = runner.run(run_hidden, topk_weights, topk_ids) - run_hidden = run_out # MoE output replaces hidden (simplified, no residual) + run_hidden = run_hidden + run_hidden_saved # Residual connection # BF16 reference + ref_hidden_saved = ref_hidden.clone() ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) for i, e in enumerate(expert_indices): dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight" @@ -138,7 +140,7 @@ def main(): up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT) act = gate_silu * up ref_out[t] += w * (act @ down_bf16.T) - ref_hidden = ref_out + ref_hidden = ref_out + ref_hidden_saved # Residual cos = F.cosine_similarity(ref_hidden.flatten().unsqueeze(0), run_hidden.flatten().unsqueeze(0)).item() has_nan = torch.isnan(run_hidden).any().item()