Add llama 4 scaling support (#28145)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
@@ -160,6 +160,14 @@ class LlamaAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
|
||||
self.do_llama_4_scaling = llama_4_scaling_config is not None
|
||||
if self.do_llama_4_scaling:
|
||||
self.llama_4_scaling_original_max_position_embeddings = (
|
||||
llama_4_scaling_config["original_max_position_embeddings"]
|
||||
)
|
||||
self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
@@ -221,6 +229,17 @@ class LlamaAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||
# Llama4 scaling
|
||||
scaling = 1 + self.llama_4_scaling_beta * torch.log(
|
||||
1
|
||||
+ torch.floor(
|
||||
positions / self.llama_4_scaling_original_max_position_embeddings
|
||||
)
|
||||
)
|
||||
# Broadcast over head_dim
|
||||
return scaling.unsqueeze(-1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -229,6 +248,9 @@ class LlamaAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
if self.do_llama_4_scaling:
|
||||
attn_scale = self._get_llama_4_attn_scale(positions)
|
||||
q = (q * attn_scale).to(q.dtype)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user