Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
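The two parameterisations carry the same information: partial_rotary_factor is rotary_dim expressed as a fraction of head_dim, so only that fraction of each head's channels receives rotary embedding. A minimal sketch of the correspondence (helper names are illustrative, not vLLM code):

    def rotary_dim_from_factor(head_dim: int, partial_rotary_factor: float) -> int:
        # A factor of 1.0 means the whole head is rotated.
        return int(head_dim * partial_rotary_factor)

    def factor_from_rotary_dim(head_dim: int, rotary_dim: int) -> float:
        return rotary_dim / head_dim

    assert rotary_dim_from_factor(128, 0.5) == 64
    assert factor_from_rotary_dim(128, 64) == 0.5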
@@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module):
         num_heads: int,
         head_dim: int,
         num_kv_heads: int,
-        rotary_dim: int,
         max_position: int = 4096 * 32,
         rope_parameters: dict | None = None,
         sliding_window: int | None = None,
@@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module):
         )
         self.rotary_emb = get_rope(
             head_size=self.head_dim,
-            rotary_dim=rotary_dim,
             max_position=max_position,
             rope_parameters=rope_parameters,
             is_neox_style=True,
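With rotary_dim removed from the call site, get_rope is left to recover the rotary dimension from rope_parameters itself. What the factor controls, as a framework-free NeoX-style sketch (hypothetical illustration, not vLLM's kernel):

    import math

    def apply_partial_rope(x: list[float], pos: int, factor: float,
                           base: float = 10000.0) -> list[float]:
        # Only the first rotary_dim = int(len(x) * factor) channels are
        # rotated; NeoX style pairs channel i with channel i + rotary_dim // 2.
        # The remaining channels pass through unchanged.
        rotary_dim = int(len(x) * factor)
        half = rotary_dim // 2
        out = list(x)
        for i in range(half):
            theta = pos * base ** (-2 * i / rotary_dim)
            a, b = x[i], x[i + half]
            out[i] = a * math.cos(theta) - b * math.sin(theta)
            out[i + half] = a * math.sin(theta) + b * math.cos(theta)
        return out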
@@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
         head_dim = getattr(config, "head_dim", None)
         if head_dim is None:
             head_dim = config.hidden_size // config.num_attention_heads
+        rotary_dim = getattr(config, "rotary_dim", head_dim)
+        config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
         if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
             max_position_embeddings = min(
                 config.max_position_embeddings, config.max_model_len
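The two added lines are the compatibility shim: a legacy rotary_dim field on the config is folded into rope_parameters before the attention layer is built. Worked through with hypothetical MiniMax-like values:

    from types import SimpleNamespace

    # Hypothetical config carrying a legacy rotary_dim field.
    config = SimpleNamespace(
        hidden_size=4096,
        num_attention_heads=32,
        rotary_dim=64,
        rope_parameters={},
    )

    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads  # 128

    rotary_dim = getattr(config, "rotary_dim", head_dim)  # 64
    config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim
    assert config.rope_parameters["partial_rotary_factor"] == 0.5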
@@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             head_dim=head_dim,
-            rotary_dim=config.rotary_dim
-            if hasattr(config, "rotary_dim")
-            else head_dim,
             num_kv_heads=config.num_key_value_heads,
             max_position=max_position_embeddings,
             rope_parameters=config.rope_parameters,
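The removed conditional is behaviourally covered by the shim added above: when the config has no rotary_dim, getattr falls back to head_dim and the factor becomes 1.0 (full-head rotation), matching the old "else head_dim" branch. A quick check:

    head_dim = 128

    class ConfigWithoutRotaryDim:
        pass

    config = ConfigWithoutRotaryDim()

    # Old call-site fallback vs. new shim default.
    old_rotary_dim = config.rotary_dim if hasattr(config, "rotary_dim") else head_dim
    new_factor = getattr(config, "rotary_dim", head_dim) / head_dim
    assert old_rotary_dim == head_dim
    assert new_factor == 1.0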