[Bugfix] Properly apply v_scale for mimo_v2_flash (#31175)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2026-01-05 18:20:46 -05:00 (committed by GitHub)
Parent: f6c0009afa
Commit: 9513029898


@@ -211,6 +211,7 @@ class MiMoV2Attention(nn.Module):
         num_kv_heads: int,
         head_dim: int,
         v_head_dim: int | None = None,
+        v_scale: float | None = None,
         sliding_window_size: int = -1,
         attention_bias: bool = False,
         add_swa_attention_sink_bias: bool = False,
@@ -241,6 +242,7 @@ class MiMoV2Attention(nn.Module):
         self.k_size = self.num_kv_heads * self.head_dim
         self.v_size = self.num_kv_heads * self.v_head_dim
+        self.v_scale = v_scale
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -304,6 +306,10 @@ class MiMoV2Attention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
+        # Apply v_scale before attention
+        if self.v_scale is not None:
+            v = v * self.v_scale
         v = v.view(-1, self.num_kv_heads, self.v_head_dim)
         v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
         v = v.view(-1, self.num_kv_heads * self.head_dim)
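
For context, a minimal standalone sketch of the runtime scaling this hunk introduces: v is multiplied by v_scale before attention, then zero-padded along its last dimension from v_head_dim up to head_dim so it matches the layout the attention kernel expects. All concrete sizes below are made-up placeholders, not values from the model.

    import torch
    import torch.nn.functional as F

    # Placeholder sizes; the real values come from the model config.
    num_tokens, num_kv_heads, head_dim, v_head_dim = 10, 4, 128, 96
    v_scale = 0.5  # stands in for config.attention_value_scale

    v = torch.randn(num_tokens, num_kv_heads * v_head_dim)
    if v_scale is not None:
        v = v * v_scale  # apply v_scale before attention, as in the fix
    v = v.view(-1, num_kv_heads, v_head_dim)
    # Zero-pad the last dim so each head occupies head_dim slots.
    v = F.pad(v, [0, head_dim - v_head_dim], value=0)
    v = v.view(-1, num_kv_heads * head_dim)
    assert v.shape == (num_tokens, num_kv_heads * head_dim)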
@@ -332,6 +338,8 @@ class MiMoV2FlashDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 1000000)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        v_scale = getattr(config, "attention_value_scale", None)
         if self.is_compressed_softmax_layer():
             self.self_attn = MiMoV2Attention(
                 hidden_size=self.hidden_size,
@@ -339,6 +347,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.swa_num_key_value_heads,
                 head_dim=config.swa_head_dim,
                 v_head_dim=getattr(config, "swa_v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=config.sliding_window_size,
                 attention_bias=config.attention_bias,
                 add_swa_attention_sink_bias=getattr(
@@ -358,6 +367,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.num_key_value_heads,
                 head_dim=config.head_dim,
                 v_head_dim=getattr(config, "v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=-1,  # normal attention
                 attention_bias=config.attention_bias,
                 layer_id=layer_id,
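
The two constructor hunks above are pure plumbing: the optional scale is read from the config once in the decoder layer and handed to both the sliding-window and the full-attention instantiations. A hedged sketch of that flow, with SimpleNamespace standing in for the real config object and build_attention standing in for MiMoV2Attention:

    from types import SimpleNamespace

    config = SimpleNamespace(attention_value_scale=0.5)
    v_scale = getattr(config, "attention_value_scale", None)  # None when the field is absent

    def build_attention(sliding_window_size, v_scale):
        # Stand-in for MiMoV2Attention(...); just records what it was given.
        return {"sliding_window_size": sliding_window_size, "v_scale": v_scale}

    swa_attn = build_attention(sliding_window_size=4096, v_scale=v_scale)  # SWA branch
    full_attn = build_attention(sliding_window_size=-1, v_scale=v_scale)   # normal attention
    assert swa_attn["v_scale"] == full_attn["v_scale"] == 0.5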
@@ -433,7 +443,6 @@ class MiMoV2Model(nn.Module):
         self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.num_redundant_experts = eplb_config.num_redundant_experts
-        self.v_scale = getattr(config, "attention_value_scale", None)
         if get_pp_group().is_first_rank or (
             config.tie_word_embeddings and get_pp_group().is_last_rank
@@ -605,18 +614,6 @@ class MiMoV2Model(nn.Module):
                 param = params_dict[name_rewritten]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                if param_name == "qkv_proj" and shard_id == "v":
-                    v_scale = (
-                        self.v_scale
-                        if self.v_scale is not None
-                        else getattr(self.config, "attention_value_scale", None)
-                    )
-                    if v_scale is not None and (
-                        name.endswith("weight_scale_inv") or name.endswith(".bias")
-                    ):
-                        loaded_weight *= float(v_scale)
                 weight_loader(param, loaded_weight, shard_id)
                 loaded_params.add(name_rewritten)
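
The deleted block had folded v_scale into the checkpoint at load time, and only for v-shard tensors whose names end in weight_scale_inv or .bias; the fix moves the scaling into the attention forward pass instead. A small sketch of the name filter the old path relied on, which suggests why load-time folding was fragile (the tensor names below are hypothetical):

    def was_scaled_at_load(name: str) -> bool:
        # Exact suffix check from the removed code path.
        return name.endswith("weight_scale_inv") or name.endswith(".bias")

    assert was_scaled_at_load("layers.0.self_attn.qkv_proj.weight_scale_inv")
    assert was_scaled_at_load("layers.0.self_attn.qkv_proj.bias")
    # A plain (unquantized, bias-free) value projection never matched,
    # so its v_scale was silently dropped.
    assert not was_scaled_at_load("layers.0.self_attn.qkv_proj.weight")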