Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-12-11 20:45:23 +00:00; committed by GitHub
parent 92fea56fd1
commit cf3eacfe58
83 changed files with 260 additions and 314 deletions


@@ -84,19 +84,18 @@ class PhiAttention(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.total_num_heads
+        self.head_size = self.hidden_size // config.num_attention_heads
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
-        assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        assert config.num_attention_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = config.num_attention_heads // tensor_model_parallel_world_size
         # pylint: disable=C0103
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_size,
-            self.total_num_heads,
+            config.num_attention_heads,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
@@ -109,13 +108,10 @@ class PhiAttention(nn.Module):
         )
         scaling = self.head_size**-0.5
-        rotary_dim = config.hidden_size // config.num_attention_heads
-        assert rotary_dim % 2 == 0
         max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
         self.rotary_emb = get_rope(
             self.head_size,
-            rotary_dim=rotary_dim,
             max_position=max_position_embeddings,
             rope_parameters=config.rope_parameters,
         )
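
With this change, model code no longer computes rotary_dim at the call site: get_rope derives it from the head size and rope_parameters["partial_rotary_factor"]. The snippet below is a minimal sketch of that derivation, assuming the conventional formula int(head_size * partial_rotary_factor) and a default factor of 1.0; the helper name resolve_rotary_dim is made up for illustration and is not vLLM's actual implementation.

def resolve_rotary_dim(head_size: int, rope_parameters: dict) -> int:
    # Rotate only the leading fraction of each head's dimensions;
    # a factor of 1.0 (the assumed default) means full rotary embedding.
    partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
    rotary_dim = int(head_size * partial_rotary_factor)
    assert rotary_dim % 2 == 0, "rotary dim must be even"
    return rotary_dim

# Example: head_size=80 with partial_rotary_factor=0.4 yields rotary_dim=32,
# i.e. only the first 32 dimensions of each head receive rotary embeddings.
print(resolve_rotary_dim(80, {"partial_rotary_factor": 0.4}))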