[Model] Add MiMo-V2-Flash support (#30836)

Signed-off-by: Abatom <abzhonghua@gmail.com>
Signed-off-by: Jumiar <liuanqim10@126.com>
Signed-off-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jumiar <liuanqim10@126.com>
Co-authored-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Zhonghua Deng
Date: 2025-12-20 01:17:03 +08:00
Committed by: GitHub
Parent: 268a972c62
Commit: 969bbc7c61
8 changed files with 789 additions and 13 deletions

@@ -277,6 +277,7 @@ class LinearBase(CustomOp):
         self.params_dtype = params_dtype
         self.quant_config = quant_config
         self.prefix = prefix
+        self.allow_fp8_block_shape_mismatch = False
         if quant_config is None:
             self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
         else:
@@ -475,6 +476,7 @@ class ColumnParallelLinear(LinearBase):
             disable_tp=disable_tp,
         )
+        self._maybe_allow_fp8_block_shape_mismatch()
         self.gather_output = gather_output
         if output_sizes is None:
@@ -509,6 +511,33 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
         self.update_param_tp_status()

+    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
+        quant_config = getattr(self, "quant_config", None)
+        weight_block = getattr(quant_config, "weight_block_size", None)
+        if (
+            weight_block is None
+            or len(weight_block) < 1
+            or len(self.output_partition_sizes) <= 1
+        ):
+            return
+        try:
+            block_n = int(weight_block[0])
+        except (ValueError, TypeError):
+            return
+        if block_n <= 0:
+            return
+        if any(size % block_n != 0 for size in self.output_partition_sizes):
+            self.allow_fp8_block_shape_mismatch = True
+            logger.debug(
+                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
+                getattr(self, "prefix", "<unknown>"),
+                block_n,
+                self.output_partition_sizes,
+            )
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)
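The helper above only sets the flag for fused layers (more than one output partition) whose per-shard widths do not divide evenly by the FP8 weight block size along the output dimension. A minimal standalone sketch of the same check, using hypothetical shard widths rather than MiMo-V2-Flash's real configuration:

```python
# Illustrative only: weight_block_size and the shard widths below are assumptions.
weight_block_size = [128, 128]               # block_n along the output dim is 128
output_partition_sizes = [4096, 512, 320]    # e.g. fused q/k/v widths on one TP rank

block_n = int(weight_block_size[0])
allow_fp8_block_shape_mismatch = any(
    size % block_n != 0 for size in output_partition_sizes
)
print(allow_fp8_block_shape_mismatch)  # True: 320 is not a multiple of 128
```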
@@ -906,9 +935,11 @@ class QKVParallelLinear(ColumnParallelLinear):
         *,
         return_bias: bool = True,
         disable_tp: bool = False,
+        v_head_size: int | None = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
+        self.v_head_size = v_head_size if v_head_size is not None else head_size
         self.total_num_heads = total_num_heads
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
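With v_head_size exposed on the constructor, models whose value heads are narrower (or wider) than their query/key heads can still build one fused QKV projection. A usage sketch, assuming the rest of the QKVParallelLinear signature is unchanged; the sizes are hypothetical, not MiMo-V2-Flash's actual configuration:

```python
# Hypothetical sizes: Q/K heads are 192-wide, V heads are 128-wide.
qkv_proj = QKVParallelLinear(
    hidden_size=4096,
    head_size=192,          # used for Q and K
    total_num_heads=32,
    total_num_kv_heads=8,
    bias=False,
    v_head_size=128,        # new argument: V heads may differ from Q/K
)
```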
@@ -924,12 +955,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_head_replicas = 1
         input_size = self.hidden_size
         output_size = (
-            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        )
+            self.num_heads * self.head_size
+            + self.num_kv_heads * self.head_size
+            + self.num_kv_heads * self.v_head_size
+        ) * tp_size
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
         ]
         super().__init__(
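With asymmetric value heads, the fused output width is no longer (num_heads + 2 * num_kv_heads) * head_size. A worked example under the same hypothetical sizes (per tensor-parallel rank, tp_size = 1):

```python
# Illustrative numbers only.
num_heads, num_kv_heads = 32, 8
head_size, v_head_size = 192, 128
tp_size = 1

q_width = num_heads * head_size         # 6144
k_width = num_kv_heads * head_size      # 1536
v_width = num_kv_heads * v_head_size    # 1024
output_size = (q_width + k_width + v_width) * tp_size
print(output_size)  # 8704, vs. (32 + 2 * 8) * 192 = 9216 under the old symmetric formula
```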
@@ -950,7 +983,8 @@ class QKVParallelLinear(ColumnParallelLinear):
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
"total": (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
@@ -958,7 +992,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         shard_size_mapping = {
             "q": self.num_heads * self.head_size,
             "k": self.num_kv_heads * self.head_size,
-            "v": self.num_kv_heads * self.head_size,
+            "v": self.num_kv_heads * self.v_head_size,
         }
         return shard_size_mapping.get(loaded_shard_id)
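The shard offsets for Q and K are unchanged, but the V shard's size, and therefore the fused total, now use v_head_size. Continuing the hypothetical sizes from above:

```python
# Illustrative numbers only.
num_heads, num_kv_heads = 32, 8
head_size, v_head_size = 192, 128

shard_offsets = {
    "q": 0,
    "k": num_heads * head_size,                    # 6144
    "v": (num_heads + num_kv_heads) * head_size,   # 7680
    "total": (num_heads + num_kv_heads) * head_size
    + num_kv_heads * v_head_size,                  # 8704
}
shard_sizes = {
    "q": num_heads * head_size,         # 6144
    "k": num_kv_heads * head_size,      # 1536
    "v": num_kv_heads * v_head_size,    # 1024
}
assert shard_offsets["v"] + shard_sizes["v"] == shard_offsets["total"]
```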
@@ -985,7 +1019,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             (
                 "v",
                 (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
+                self.total_num_kv_heads * self.v_head_size,
             ),
         ]
@@ -1110,7 +1144,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 (
                     "v",
                     (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                    self.total_num_kv_heads * self.head_size,
+                    self.total_num_kv_heads * self.v_head_size,
                 ),
             ]
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
@@ -1139,11 +1173,12 @@ class QKVParallelLinear(ColumnParallelLinear):
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
"total": (
(self.total_num_heads + 2 * self.total_num_kv_heads)
* self.head_size,
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0,
),
}
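The same arithmetic applies on the checkpoint side with the total_* head counts: the full fused QKV weight's leading dimension becomes (total_num_heads + total_num_kv_heads) * head_size + total_num_kv_heads * v_head_size, and it splits cleanly into the three shards. A sketch with the same hypothetical sizes:

```python
import torch

# Illustrative numbers only.
total_num_heads, total_num_kv_heads = 32, 8
head_size, v_head_size = 192, 128
hidden_size = 4096

fused_width = (
    (total_num_heads + total_num_kv_heads) * head_size
    + total_num_kv_heads * v_head_size
)
qkv_weight = torch.empty(fused_width, hidden_size)

q, k, v = torch.split(
    qkv_weight,
    [
        total_num_heads * head_size,         # 6144
        total_num_kv_heads * head_size,      # 1536
        total_num_kv_heads * v_head_size,    # 1024
    ],
    dim=0,
)
assert v.shape[0] == total_num_kv_heads * v_head_size
```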
@@ -1170,7 +1205,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size = self.num_kv_heads * self.head_size
             elif loaded_shard_id == "v":
                 shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
-                shard_size = self.num_kv_heads * self.head_size
+                shard_size = self.num_kv_heads * self.v_head_size
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -1199,10 +1234,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                     ),
                     "v": (
                         (self.num_heads + self.num_kv_heads) * self.head_size,
-                        self.num_kv_heads * self.head_size,
+                        self.num_kv_heads * self.v_head_size,
                     ),
                     "total": (
-                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+                        (self.num_heads + self.num_kv_heads) * self.head_size
+                        + self.num_kv_heads * self.v_head_size,
                         0,
                     ),
                 }