From 842e6e138115a188a71402f4a43d404f42369152 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 15:54:40 +0000 Subject: [PATCH] =?UTF-8?q?Fix=20view=E2=86=92reshape=20for=20non-contiguo?= =?UTF-8?q?us=20tensor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_e2e_decode_b200.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_e2e_decode_b200.py b/tests/test_e2e_decode_b200.py index f694324a..48e8a2c3 100644 --- a/tests/test_e2e_decode_b200.py +++ b/tests/test_e2e_decode_b200.py @@ -227,7 +227,7 @@ def run_layer(hidden, layer_id, runners, weights, cos_sin, positions, # Output projection: inverse RoPE + o_a BMM + o_b o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE) - o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2) + o_grouped = o_inv.reshape(NT, OG, HPG * HD).permute(1, 0, 2) woa_3d = woa.view(OG, OL, HPG * HD) z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL) attn_out = r_wob.run(z)