[Bugfix] Properly apply v_scale for mimo_v2_flash (#31175)
Signed-off-by: mgoin <mgoin64@gmail.com>
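This moves the application of `attention_value_scale` from weight-loading time to runtime: the scale is read from the config in `MiMoV2FlashDecoderLayer`, passed into `MiMoV2Attention` as `v_scale`, and multiplied into the value activations just before attention, replacing the previous `load_weights` hook that rescaled only the v shard's `weight_scale_inv` and `.bias` tensors.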
@@ -211,6 +211,7 @@ class MiMoV2Attention(nn.Module):
         num_kv_heads: int,
         head_dim: int,
         v_head_dim: int | None = None,
+        v_scale: float | None = None,
         sliding_window_size: int = -1,
         attention_bias: bool = False,
         add_swa_attention_sink_bias: bool = False,
@@ -241,6 +242,7 @@ class MiMoV2Attention(nn.Module):
         self.k_size = self.num_kv_heads * self.head_dim
         self.v_size = self.num_kv_heads * self.v_head_dim
 
+        self.v_scale = v_scale
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -304,6 +306,10 @@ class MiMoV2Attention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
 
+        # Apply v_scale before attention
+        if self.v_scale is not None:
+            v = v * self.v_scale
+
         v = v.view(-1, self.num_kv_heads, self.v_head_dim)
         v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
         v = v.view(-1, self.num_kv_heads * self.head_dim)
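For context, scaling `v` before attention is equivalent to scaling the attention output, because the output is linear in the values. A minimal sketch of that equivalence (shapes and the scale value are illustrative, not taken from this model):

    import torch

    torch.manual_seed(0)
    head_dim = 8
    q, k, v = torch.randn(3, 4, head_dim).unbind(0)  # illustrative shapes
    v_scale = 0.5                                    # illustrative value

    attn = torch.softmax(q @ k.T * head_dim**-0.5, dim=-1)
    out_v_scaled = attn @ (v * v_scale)   # scale applied to v, as in this commit
    out_o_scaled = (attn @ v) * v_scale   # scale applied to the output instead
    assert torch.allclose(out_v_scaled, out_o_scaled)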
@@ -332,6 +338,8 @@ class MiMoV2FlashDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 1000000)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
 
+        v_scale = getattr(config, "attention_value_scale", None)
+
         if self.is_compressed_softmax_layer():
             self.self_attn = MiMoV2Attention(
                 hidden_size=self.hidden_size,
@@ -339,6 +347,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.swa_num_key_value_heads,
                 head_dim=config.swa_head_dim,
                 v_head_dim=getattr(config, "swa_v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=config.sliding_window_size,
                 attention_bias=config.attention_bias,
                 add_swa_attention_sink_bias=getattr(
@@ -358,6 +367,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.num_key_value_heads,
                 head_dim=config.head_dim,
                 v_head_dim=getattr(config, "v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=-1,  # normal attention
                 attention_bias=config.attention_bias,
                 layer_id=layer_id,
@@ -433,7 +443,6 @@ class MiMoV2Model(nn.Module):
         self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.num_redundant_experts = eplb_config.num_redundant_experts
-        self.v_scale = getattr(config, "attention_value_scale", None)
 
         if get_pp_group().is_first_rank or (
             config.tie_word_embeddings and get_pp_group().is_last_rank
@@ -605,18 +614,6 @@ class MiMoV2Model(nn.Module):
 
                 param = params_dict[name_rewritten]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-
-                if param_name == "qkv_proj" and shard_id == "v":
-                    v_scale = (
-                        self.v_scale
-                        if self.v_scale is not None
-                        else getattr(self.config, "attention_value_scale", None)
-                    )
-                    if v_scale is not None and (
-                        name.endswith("weight_scale_inv") or name.endswith(".bias")
-                    ):
-                        loaded_weight *= float(v_scale)
-
                 weight_loader(param, loaded_weight, shard_id)
                 loaded_params.add(name_rewritten)
 
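Note: the removed `load_weights` path applied the scale only to tensors whose names end in `weight_scale_inv` or `.bias`, so it presumably depended on the checkpoint's quantization layout; multiplying the activations in `forward` applies the scale regardless of how the v projection weights are stored.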