[Model][Quantization] Add GGUF support for MiniMax-M2.1 (#36965)
Signed-off-by: kangletian <Letian.Kang@amd.com>
This commit is contained in:
@@ -948,6 +948,7 @@ class ModelConfig:
|
||||
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||
"mxfp4",
|
||||
"cpu_awq",
|
||||
"gguf",
|
||||
]
|
||||
quantization_methods = [
|
||||
q for q in supported_quantization if q not in overrides
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@@ -79,6 +82,16 @@ class GGUFConfig(QuantizationConfig):
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
    """Build a :class:`GGUFConfig`.

    GGUF carries its quantization metadata inside the checkpoint
    tensors themselves, so nothing is read from the HF quant config.
    """
    return cls()
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg: dict[str, Any], user_quant: str | None
|
||||
) -> "QuantizationMethods | None":
|
||||
# When user explicitly specifies --quantization gguf, override
|
||||
# whatever quantization method is in the HF model config (e.g. fp8).
|
||||
if user_quant == "gguf":
|
||||
return "gguf"
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names,
|
||||
get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator,
|
||||
gguf_quant_weights_iterator_multi,
|
||||
)
|
||||
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
@@ -74,6 +75,31 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
"or <repo_id>:<quant_type>)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_all_gguf_files(model_path: str) -> list[str]:
|
||||
"""Discover all GGUF shard files from a single shard path.
|
||||
|
||||
Supports variable-width shard indices by dynamically detecting
|
||||
the padding from the original filename.
|
||||
E.g. ``*-00001-of-00005.gguf`` → all 5 shards,
|
||||
``*-01-of-15.gguf`` → all 15 shards.
|
||||
"""
|
||||
match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path)
|
||||
if not match:
|
||||
return [model_path]
|
||||
total = int(match.group(2))
|
||||
num_digits = len(match.group(1))
|
||||
prefix = model_path[: match.start(1)]
|
||||
suffix = model_path[match.end(2) :]
|
||||
files = []
|
||||
for i in range(1, total + 1):
|
||||
shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}"
|
||||
if os.path.isfile(shard_path):
|
||||
files.append(shard_path)
|
||||
if files:
|
||||
logger.info("Discovered %d GGUF shard files", len(files))
|
||||
return files if files else [model_path]
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||
"""
|
||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||
@@ -145,6 +171,29 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
|
||||
)
|
||||
)
|
||||
if model_type == "minimax_m2":
|
||||
model_type = "minimax-m2"
|
||||
# GGUF layer map assumes merged expert weights
|
||||
# map them manually like deepseek2
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
|
||||
f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias"
|
||||
)
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
|
||||
f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight"
|
||||
)
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
|
||||
f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight"
|
||||
)
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
|
||||
f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight"
|
||||
)
|
||||
sideload_params.append(
|
||||
re.compile(
|
||||
f"model\\.layers\\.{idx}"
|
||||
r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)"
|
||||
)
|
||||
)
|
||||
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
@@ -190,6 +239,13 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
revert_hf_rename(name): tensor for name, tensor in state_dict.items()
|
||||
}
|
||||
|
||||
if model_type == "minimax-m2" and not hf_checkpoint_map:
|
||||
# Reverse HF convention: mlp -> block_sparse_moe
|
||||
state_dict = {
|
||||
name.replace(".mlp.", ".block_sparse_moe."): tensor
|
||||
for name, tensor in state_dict.items()
|
||||
}
|
||||
|
||||
def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
|
||||
"""
|
||||
Map HuggingFace parameter name to GGUF tensor name.
|
||||
@@ -277,9 +333,10 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
model_name_or_path: str,
|
||||
gguf_to_hf_name_map: dict[str, str],
|
||||
) -> dict[str, str]:
|
||||
weight_type_map = get_gguf_weight_type_map(
|
||||
model_name_or_path, gguf_to_hf_name_map
|
||||
)
|
||||
gguf_files = self._get_all_gguf_files(model_name_or_path)
|
||||
weight_type_map = {}
|
||||
for f in gguf_files:
|
||||
weight_type_map.update(get_gguf_weight_type_map(f, gguf_to_hf_name_map))
|
||||
is_multimodal = hasattr(model_config.hf_config, "vision_config")
|
||||
if is_multimodal:
|
||||
mmproj_file = detect_gguf_multimodal(model_name_or_path)
|
||||
@@ -321,7 +378,15 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
)
|
||||
yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
|
||||
|
||||
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
|
||||
gguf_files = self._get_all_gguf_files(model_name_or_path)
|
||||
if len(gguf_files) > 1:
|
||||
yield from gguf_quant_weights_iterator_multi(
|
||||
gguf_files, gguf_to_hf_name_map
|
||||
)
|
||||
else:
|
||||
yield from gguf_quant_weights_iterator(
|
||||
model_name_or_path, gguf_to_hf_name_map
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config)
|
||||
@@ -340,9 +405,11 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
local_model_path = self._prepare_weights(model_config)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
# we can only know if tie word embeddings after mapping weights
|
||||
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
||||
local_model_path, gguf_weights_map
|
||||
):
|
||||
gguf_files = self._get_all_gguf_files(local_model_path)
|
||||
all_extra_names = []
|
||||
for f in gguf_files:
|
||||
all_extra_names.extend(get_gguf_extra_tensor_names(f, gguf_weights_map))
|
||||
if "lm_head.weight" in all_extra_names:
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
weight_type_map = self._get_gguf_weight_type(
|
||||
|
||||
@@ -1222,6 +1222,49 @@ def gguf_quant_weights_iterator(
|
||||
yield name, param
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator_multi(
    gguf_files: list[str], gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """
    Iterate over the quant weights across multiple GGUF shard files
    and convert them to torch tensors.

    Like gguf_quant_weights_iterator, all weight-type markers are
    yielded before any weight data, so packed layers whose parts use
    different quant types are handled correctly.
    """
    # One reader per shard, kept open across both passes.
    readers = [gguf.GGUFReader(shard) for shard in gguf_files]
    unquantized = ("F32", "BF16", "F16")

    # Pass 1: quant-type markers for every mapped, quantized tensor.
    for shard_reader in readers:
        for gguf_tensor in shard_reader.tensors:
            hf_name = gguf_to_hf_name_map.get(gguf_tensor.name)
            if hf_name is None:
                continue
            qtype = gguf_tensor.tensor_type
            if qtype.name in unquantized:
                continue
            yield hf_name.replace("weight", "qweight_type"), torch.tensor(qtype)

    # Pass 2: the tensor payloads themselves.
    for shard_reader in readers:
        for gguf_tensor in shard_reader.tensors:
            hf_name = gguf_to_hf_name_map.get(gguf_tensor.name)
            if hf_name is None:
                continue
            qtype = gguf_tensor.tensor_type
            raw = gguf_tensor.data
            if qtype.name not in unquantized:
                hf_name = hf_name.replace("weight", "qweight")
            if qtype.name == "BF16" and gguf_tensor.data.dtype == np.uint8:
                # BF16 payloads arrive as raw bytes; reinterpret as u16
                # (byteswapping for swapped-endian files) before viewing
                # the buffer as bfloat16.
                raw = raw.view(np.uint16)
                if shard_reader.byte_order == "S":
                    raw = raw.byteswap()
                param = torch.tensor(raw).view(torch.bfloat16)
            else:
                param = torch.tensor(raw)
            yield hf_name, param
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
|
||||
@@ -331,7 +331,7 @@ class MiniMaxM2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=None,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
)
|
||||
else:
|
||||
@@ -518,7 +518,10 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=None
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
Reference in New Issue
Block a user