Fix: use reshape instead of view for the non-contiguous o_inv tensor
This commit is contained in:
@@ -227,7 +227,7 @@ def run_layer(hidden, layer_id, runners, weights, cos_sin, positions,
|
||||
|
||||
# Output projection: inverse RoPE + o_a BMM + o_b
|
||||
o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE)
|
||||
o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2)
|
||||
o_grouped = o_inv.reshape(NT, OG, HPG * HD).permute(1, 0, 2)
|
||||
woa_3d = woa.view(OG, OL, HPG * HD)
|
||||
z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL)
|
||||
attn_out = r_wob.run(z)
|
||||
|
||||
Reference in New Issue
Block a user