[Minor][Models] Pass partial_rotary_factor parameter to rope (#17266)

Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
Wanrui Dai
2025-04-28 12:28:59 +08:00
committed by GitHub
parent 8262a3e23b
commit 7fcc4223dc
3 changed files with 10 additions and 8 deletions

View File

@@ -130,8 +130,8 @@ class LlamaAttention(nn.Module):
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -163,11 +163,12 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
if hasattr(config, "interleaved_sliding_window"):