Fix view→reshape for non-contiguous tensor

This commit is contained in:
2026-05-19 15:54:40 +00:00
parent f0f8d8211b
commit 842e6e1381

View File

@@ -227,7 +227,7 @@ def run_layer(hidden, layer_id, runners, weights, cos_sin, positions,
# Output projection: inverse RoPE + o_a BMM + o_b
o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE)
o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2)
o_grouped = o_inv.reshape(NT, OG, HPG * HD).permute(1, 0, 2)
woa_3d = woa.view(OG, OL, HPG * HD)
z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL)
attn_out = r_wob.run(z)