diff --git a/tests/test_e2e_decode_b200.py b/tests/test_e2e_decode_b200.py index f694324a..48e8a2c3 100644 --- a/tests/test_e2e_decode_b200.py +++ b/tests/test_e2e_decode_b200.py @@ -227,7 +227,7 @@ def run_layer(hidden, layer_id, runners, weights, cos_sin, positions, # Output projection: inverse RoPE + o_a BMM + o_b o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE) - o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2) + o_grouped = o_inv.reshape(NT, OG, HPG * HD).permute(1, 0, 2) woa_3d = woa.view(OG, OL, HPG * HD) z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL) attn_out = r_wob.run(z)