[BugFix] Fix EXAONE4 rotary embeddings (#23918)
Signed-off-by: lkm2835 <lkm2835@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -164,8 +164,8 @@ class Exaone4Attention(nn.Module):
         is_sliding = config.layer_types[layer_idx] == "sliding_attention"
         self.sliding_window = config.sliding_window if is_sliding else None

-        # apply rotary embeddings to every layer
-        self.apply_all_layers = not is_sliding
+        # apply rotary embeddings to every layer in full attention models
+        self.apply_rope_all_layers = "sliding_attention" not in config.layer_types

         self.rotary_emb = get_rope(
             self.head_dim,
@@ -201,7 +201,7 @@ class Exaone4Attention(nn.Module):
         k = self.k_norm(k)
         k = k.flatten(-2, -1)

-        if self.sliding_window or self.apply_all_layers:
+        if self.sliding_window or self.apply_rope_all_layers:
             q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
|
|||||||
Reference in New Issue
Block a user