[Gemma4] Support quantized MoE (#39045)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
@@ -1248,21 +1248,27 @@ class Gemma4Model(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
# MoE expert weight mapping: checkpoint 3D packed tensors are
|
# MoE expert weight mapping: checkpoint can have either:
|
||||||
# exploded in _weight_iterator to per-expert 2D weights like:
|
# 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
|
||||||
|
# 2. Already per-expert 2D weights (if quantized)
|
||||||
|
# Map to FusedMoE parameters:
|
||||||
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
|
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
|
||||||
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
|
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
|
||||||
# moe.experts.{id}.down_proj → FusedMoE w2
|
# moe.experts.{id}.down_proj → FusedMoE w2
|
||||||
# We build the mapping directly since Gemma4 uses bare param
|
#
|
||||||
# names (no .weight suffix) unlike standard MoE checkpoints.
|
# Use prefix matching to handle both weights and
|
||||||
|
# quantization scale parameters. The param_name is a prefix ending
|
||||||
|
# in underscore, and weight_name ends with a dot, so that:
|
||||||
|
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
|
||||||
|
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
|
||||||
num_experts = getattr(self.config, "num_experts", None) or 0
|
num_experts = getattr(self.config, "num_experts", None) or 0
|
||||||
expert_params_mapping = [
|
expert_params_mapping = [
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
(
|
(
|
||||||
"experts.w13_weight"
|
"experts.w13_"
|
||||||
if proj_name in ["gate_proj", "up_proj"]
|
if proj_name in ["gate_proj", "up_proj"]
|
||||||
else "experts.w2_weight",
|
else "experts.w2_",
|
||||||
f"experts.{expert_id}.{proj_name}",
|
f"experts.{expert_id}.{proj_name}.",
|
||||||
expert_id,
|
expert_id,
|
||||||
shard_id,
|
shard_id,
|
||||||
)
|
)
|
||||||
@@ -1322,9 +1328,21 @@ class Gemma4Model(nn.Module):
|
|||||||
expert_id,
|
expert_id,
|
||||||
shard_id,
|
shard_id,
|
||||||
) in expert_params_mapping:
|
) in expert_params_mapping:
|
||||||
if weight_name not in name:
|
# Match both:
|
||||||
|
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
|
||||||
|
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
|
||||||
|
# weight_name has trailing dot, so check with and without it
|
||||||
|
weight_name_base = weight_name.rstrip(".")
|
||||||
|
if weight_name in name:
|
||||||
|
# Has suffix (e.g., .weight_scale)
|
||||||
|
moe_name = name.replace(weight_name, param_name)
|
||||||
|
elif name.endswith(weight_name_base):
|
||||||
|
# Bare weight (no suffix)
|
||||||
|
moe_name = name.replace(
|
||||||
|
weight_name_base, param_name.rstrip("_") + "_weight"
|
||||||
|
)
|
||||||
|
else:
|
||||||
continue
|
continue
|
||||||
moe_name = name.replace(weight_name, param_name)
|
|
||||||
if moe_name not in params_dict:
|
if moe_name not in params_dict:
|
||||||
continue
|
continue
|
||||||
if is_pp_missing_parameter(moe_name, self):
|
if is_pp_missing_parameter(moe_name, self):
|
||||||
@@ -1334,15 +1352,12 @@ class Gemma4Model(nn.Module):
|
|||||||
# orientation for FusedMoE after _weight_iterator:
|
# orientation for FusedMoE after _weight_iterator:
|
||||||
# gate/up: [I, H] → w1/w3 expects [I, H]
|
# gate/up: [I, H] → w1/w3 expects [I, H]
|
||||||
# down: [H, I] → w2 expects [H, I]
|
# down: [H, I] → w2 expects [H, I]
|
||||||
assert loaded_weight.dim() == 2, (
|
# Scales and other quantization params may be 1D or scalar.
|
||||||
f"Expected 2D expert weight for {weight_name}, "
|
|
||||||
f"got shape {loaded_weight.shape}"
|
|
||||||
)
|
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
weight_name + ".weight",
|
moe_name, # Pass mapped name (handles both weights and scales)
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id,
|
expert_id=expert_id,
|
||||||
)
|
)
|
||||||
@@ -1499,6 +1514,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
|||||||
".moe.down_proj",
|
".moe.down_proj",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remap individual 2D expert weights:
|
||||||
|
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
|
||||||
|
# (This handles per-expert 2D quantized weights)
|
||||||
|
name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)
|
||||||
|
|
||||||
# MoE expert weights: checkpoint stores as 3D packed
|
# MoE expert weights: checkpoint stores as 3D packed
|
||||||
# tensors. Explode into per-expert 2D weights for
|
# tensors. Explode into per-expert 2D weights for
|
||||||
# FusedMoE weight_loader.
|
# FusedMoE weight_loader.
|
||||||
|
|||||||
Reference in New Issue
Block a user