Fix multi-layer test: add residual connections

This commit is contained in:
2026-05-17 22:55:40 +00:00
parent 11dce13afe
commit 8637020487

View File

@@ -111,11 +111,13 @@ def main():
for layer in range(NUM_LAYERS):
with torch.no_grad():
# Runner
run_hidden_saved = run_hidden.clone()
runner.compute_activation_global_scales(run_hidden, topk_weights, topk_ids)
run_out = runner.run(run_hidden, topk_weights, topk_ids)
run_hidden = run_out # MoE output replaces hidden (simplified, no residual)
run_hidden = run_hidden + run_hidden_saved # Residual connection
# BF16 reference
ref_hidden_saved = ref_hidden.clone()
ref_out = torch.zeros(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE)
for i, e in enumerate(expert_indices):
dk = f"layers.{LAYER_IDX}.mlp.experts.{e}.down_proj.weight"
@@ -138,7 +140,7 @@ def main():
up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
act = gate_silu * up
ref_out[t] += w * (act @ down_bf16.T)
ref_hidden = ref_out
ref_hidden = ref_out + ref_hidden_saved # Residual
cos = F.cosine_similarity(ref_hidden.flatten().unsqueeze(0), run_hidden.flatten().unsqueeze(0)).item()
has_nan = torch.isnan(run_hidden).any().item()