From 95130298981428053a23376e89a771afc7fc37af Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 5 Jan 2026 18:20:46 -0500
Subject: [PATCH] [Bugfix] Properly apply v_scale for mimo_v2_flash (#31175)

Signed-off-by: mgoin
---
 vllm/model_executor/models/mimo_v2_flash.py | 23 +++++++++------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/mimo_v2_flash.py b/vllm/model_executor/models/mimo_v2_flash.py
index 12b486f00..98d40a384 100644
--- a/vllm/model_executor/models/mimo_v2_flash.py
+++ b/vllm/model_executor/models/mimo_v2_flash.py
@@ -211,6 +211,7 @@ class MiMoV2Attention(nn.Module):
         num_kv_heads: int,
         head_dim: int,
         v_head_dim: int | None = None,
+        v_scale: float | None = None,
         sliding_window_size: int = -1,
         attention_bias: bool = False,
         add_swa_attention_sink_bias: bool = False,
@@ -241,6 +242,7 @@ class MiMoV2Attention(nn.Module):
 
         self.k_size = self.num_kv_heads * self.head_dim
         self.v_size = self.num_kv_heads * self.v_head_dim
+        self.v_scale = v_scale
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -304,6 +306,10 @@ class MiMoV2Attention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
 
+        # Apply v_scale before attention
+        if self.v_scale is not None:
+            v = v * self.v_scale
+
         v = v.view(-1, self.num_kv_heads, self.v_head_dim)
         v = torch.nn.functional.pad(v, [0, self.head_dim - self.v_head_dim], value=0)
         v = v.view(-1, self.num_kv_heads * self.head_dim)
@@ -332,6 +338,8 @@ class MiMoV2FlashDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 1000000)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
 
+        v_scale = getattr(config, "attention_value_scale", None)
+
         if self.is_compressed_softmax_layer():
             self.self_attn = MiMoV2Attention(
                 hidden_size=self.hidden_size,
@@ -339,6 +347,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.swa_num_key_value_heads,
                 head_dim=config.swa_head_dim,
                 v_head_dim=getattr(config, "swa_v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=config.sliding_window_size,
                 attention_bias=config.attention_bias,
                 add_swa_attention_sink_bias=getattr(
@@ -358,6 +367,7 @@ class MiMoV2FlashDecoderLayer(nn.Module):
                 num_kv_heads=config.num_key_value_heads,
                 head_dim=config.head_dim,
                 v_head_dim=getattr(config, "v_head_dim", None),
+                v_scale=v_scale,
                 sliding_window_size=-1,  # normal attention
                 attention_bias=config.attention_bias,
                 layer_id=layer_id,
@@ -433,7 +443,6 @@ class MiMoV2Model(nn.Module):
         self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.num_redundant_experts = eplb_config.num_redundant_experts
-        self.v_scale = getattr(config, "attention_value_scale", None)
 
         if get_pp_group().is_first_rank or (
             config.tie_word_embeddings and get_pp_group().is_last_rank
@@ -605,18 +614,6 @@ class MiMoV2Model(nn.Module):
 
                 param = params_dict[name_rewritten]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-
-                if param_name == "qkv_proj" and shard_id == "v":
-                    v_scale = (
-                        self.v_scale
-                        if self.v_scale is not None
-                        else getattr(self.config, "attention_value_scale", None)
-                    )
-                    if v_scale is not None and (
-                        name.endswith("weight_scale_inv") or name.endswith(".bias")
-                    ):
-                        loaded_weight *= float(v_scale)
-
                 weight_loader(param, loaded_weight, shard_id)
                 loaded_params.add(name_rewritten)
 
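
Note (not part of the patch): the change moves v_scale from a load-time adjustment
of weight_scale_inv/bias to a runtime multiply on v in the attention forward pass.
The minimal standalone sketch below only checks the algebra that makes the runtime
multiply safe: attention output is linear in V, so a scalar v_scale applied to V
before attention gives the same result as scaling the attention output afterwards.
Shapes and the v_scale value are illustrative, not taken from the model config.

    import torch

    torch.manual_seed(0)
    q = torch.randn(2, 4, 8, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 8, 64)
    v = torch.randn(2, 4, 8, 64)
    v_scale = 0.5

    # Plain softmax attention, written out to keep the sketch self-contained.
    attn = torch.softmax(q @ k.transpose(-1, -2) / 64**0.5, dim=-1)

    out_scale_v_first = attn @ (v * v_scale)  # scale V before attention, as in the patch
    out_scale_output = (attn @ v) * v_scale   # scale the attention output instead

    assert torch.allclose(out_scale_v_first, out_scale_output, atol=1e-6)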