From 429fc3db40126ffe5349607503628280dd8db1fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 09:23:10 +0000 Subject: [PATCH] Fix expert weight indexing for 1D tensor --- tests/test_residual_diagnostic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_residual_diagnostic.py b/tests/test_residual_diagnostic.py index 50edcdb4..1b80fde9 100644 --- a/tests/test_residual_diagnostic.py +++ b/tests/test_residual_diagnostic.py @@ -215,7 +215,8 @@ def main(): routed_out = torch.zeros_like(x_ffn_normed) for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)): - routed_out = routed_out + (out.float() * wt.item()).bfloat16() + w_val = wt.item() if wt.dim() == 0 else wt[i].item() if wt.dim() == 1 else wt.flatten()[i].item() + routed_out = routed_out + (out.float() * w_val).bfloat16() routed_out = (routed_out.float() * routed_scaling).bfloat16() # Shared expert