Fix expert weight indexing for 1D tensor

This commit is contained in:
2026-05-31 09:23:10 +00:00
parent 33004dcbf4
commit 429fc3db40

View File

@@ -215,7 +215,8 @@ def main():
routed_out = torch.zeros_like(x_ffn_normed)
for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)):
routed_out = routed_out + (out.float() * wt.item()).bfloat16()
w_val = wt.item() if wt.dim() == 0 else wt[i].item() if wt.dim() == 1 else wt.flatten()[i].item()
routed_out = routed_out + (out.float() * w_val).bfloat16()
routed_out = (routed_out.float() * routed_scaling).bfloat16()
# Shared expert