diff --git a/tests/test_residual_diagnostic.py b/tests/test_residual_diagnostic.py index 50edcdb4..1b80fde9 100644 --- a/tests/test_residual_diagnostic.py +++ b/tests/test_residual_diagnostic.py @@ -215,7 +215,8 @@ def main(): routed_out = torch.zeros_like(x_ffn_normed) for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)): - routed_out = routed_out + (out.float() * wt.item()).bfloat16() + w_val = wt.item() if wt.dim() == 0 else wt[i].item() if wt.dim() == 1 else wt.flatten()[i].item() + routed_out = routed_out + (out.float() * w_val).bfloat16() routed_out = (routed_out.float() * routed_scaling).bfloat16() # Shared expert