Fix expert weight indexing for 1D tensor
This commit is contained in:
@@ -215,7 +215,8 @@ def main():
|
||||
|
||||
routed_out = torch.zeros_like(x_ffn_normed)
|
||||
for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)):
|
||||
routed_out = routed_out + (out.float() * wt.item()).bfloat16()
|
||||
w_val = wt.item() if wt.dim() == 0 else wt[i].item() if wt.dim() == 1 else wt.flatten()[i].item()
|
||||
routed_out = routed_out + (out.float() * w_val).bfloat16()
|
||||
routed_out = (routed_out.float() * routed_scaling).bfloat16()
|
||||
|
||||
# Shared expert
|
||||
|
||||
Reference in New Issue
Block a user