[Model] Add MiMo-V2-Flash support (#30836)
Signed-off-by: Abatom <abzhonghua@gmail.com>
Signed-off-by: Jumiar <liuanqim10@126.com>
Signed-off-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jumiar <liuanqim10@126.com>
Co-authored-by: Zyann7 <zyann7@outlook.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -277,6 +277,7 @@ class LinearBase(CustomOp):
|
||||
self.params_dtype = params_dtype
|
||||
self.quant_config = quant_config
|
||||
self.prefix = prefix
|
||||
self.allow_fp8_block_shape_mismatch = False
|
||||
if quant_config is None:
|
||||
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
|
||||
else:
|
||||
@@ -475,6 +476,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
self._maybe_allow_fp8_block_shape_mismatch()
|
||||
self.gather_output = gather_output
|
||||
|
||||
if output_sizes is None:
|
||||
@@ -509,6 +511,33 @@ class ColumnParallelLinear(LinearBase):
|
||||
self.register_parameter("bias", None)
|
||||
self.update_param_tp_status()
|
||||
|
||||
def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
|
||||
quant_config = getattr(self, "quant_config", None)
|
||||
weight_block = getattr(quant_config, "weight_block_size", None)
|
||||
if (
|
||||
weight_block is None
|
||||
or len(weight_block) < 1
|
||||
or len(self.output_partition_sizes) <= 1
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
block_n = int(weight_block[0])
|
||||
except (ValueError, TypeError):
|
||||
return
|
||||
|
||||
if block_n <= 0:
|
||||
return
|
||||
|
||||
if any(size % block_n != 0 for size in self.output_partition_sizes):
|
||||
self.allow_fp8_block_shape_mismatch = True
|
||||
logger.debug(
|
||||
"Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
|
||||
getattr(self, "prefix", "<unknown>"),
|
||||
block_n,
|
||||
self.output_partition_sizes,
|
||||
)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
@@ -906,9 +935,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
v_head_size: int | None = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.v_head_size = v_head_size if v_head_size is not None else head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
@@ -924,12 +955,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (
|
||||
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
)
|
||||
self.num_heads * self.head_size
|
||||
+ self.num_kv_heads * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size
|
||||
) * tp_size
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
@@ -950,7 +983,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
"q": 0,
|
||||
"k": self.num_heads * self.head_size,
|
||||
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
|
||||
"total": (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size,
|
||||
}
|
||||
return shard_offset_mapping.get(loaded_shard_id)
|
||||
|
||||
@@ -958,7 +992,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size_mapping = {
|
||||
"q": self.num_heads * self.head_size,
|
||||
"k": self.num_kv_heads * self.head_size,
|
||||
"v": self.num_kv_heads * self.head_size,
|
||||
"v": self.num_kv_heads * self.v_head_size,
|
||||
}
|
||||
return shard_size_mapping.get(loaded_shard_id)
|
||||
|
||||
@@ -985,7 +1019,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
(
|
||||
"v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -1110,7 +1144,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
(
|
||||
"v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
]
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
@@ -1139,11 +1173,12 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
"v": (
|
||||
(self.total_num_heads + self.total_num_kv_heads)
|
||||
* self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.v_head_size,
|
||||
),
|
||||
"total": (
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads)
|
||||
* self.head_size,
|
||||
(self.total_num_heads + self.total_num_kv_heads)
|
||||
* self.head_size
|
||||
+ self.total_num_kv_heads * self.v_head_size,
|
||||
0,
|
||||
),
|
||||
}
|
||||
@@ -1170,7 +1205,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_size = self.num_kv_heads * self.v_head_size
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
@@ -1199,10 +1234,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
),
|
||||
"v": (
|
||||
(self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
self.num_kv_heads * self.v_head_size,
|
||||
),
|
||||
"total": (
|
||||
(self.num_heads + 2 * self.num_kv_heads) * self.head_size,
|
||||
(self.num_heads + self.num_kv_heads) * self.head_size
|
||||
+ self.num_kv_heads * self.v_head_size,
|
||||
0,
|
||||
),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user