[Gemma4] Support quantized MoE (#39045)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
@@ -1248,21 +1248,27 @@ class Gemma4Model(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# MoE expert weight mapping: checkpoint 3D packed tensors are
|
||||
# exploded in _weight_iterator to per-expert 2D weights like:
|
||||
# MoE expert weight mapping: checkpoint can have either:
|
||||
# 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
|
||||
# 2. Already per-expert 2D weights (if quantized)
|
||||
# Map to FusedMoE parameters:
|
||||
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
|
||||
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
|
||||
# moe.experts.{id}.down_proj → FusedMoE w2
|
||||
# We build the mapping directly since Gemma4 uses bare param
|
||||
# names (no .weight suffix) unlike standard MoE checkpoints.
|
||||
#
|
||||
# Use prefix matching to handle both weights and
|
||||
# quantization scale parameters. The param_name is a prefix ending
|
||||
# in underscore, and weight_name ends with a dot, so that:
|
||||
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
|
||||
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
|
||||
num_experts = getattr(self.config, "num_experts", None) or 0
|
||||
expert_params_mapping = [
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
(
|
||||
"experts.w13_weight"
|
||||
"experts.w13_"
|
||||
if proj_name in ["gate_proj", "up_proj"]
|
||||
else "experts.w2_weight",
|
||||
f"experts.{expert_id}.{proj_name}",
|
||||
else "experts.w2_",
|
||||
f"experts.{expert_id}.{proj_name}.",
|
||||
expert_id,
|
||||
shard_id,
|
||||
)
|
||||
@@ -1322,9 +1328,21 @@ class Gemma4Model(nn.Module):
|
||||
expert_id,
|
||||
shard_id,
|
||||
) in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
# Match both:
|
||||
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
|
||||
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
|
||||
# weight_name has trailing dot, so check with and without it
|
||||
weight_name_base = weight_name.rstrip(".")
|
||||
if weight_name in name:
|
||||
# Has suffix (e.g., .weight_scale)
|
||||
moe_name = name.replace(weight_name, param_name)
|
||||
elif name.endswith(weight_name_base):
|
||||
# Bare weight (no suffix)
|
||||
moe_name = name.replace(
|
||||
weight_name_base, param_name.rstrip("_") + "_weight"
|
||||
)
|
||||
else:
|
||||
continue
|
||||
moe_name = name.replace(weight_name, param_name)
|
||||
if moe_name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(moe_name, self):
|
||||
@@ -1334,15 +1352,12 @@ class Gemma4Model(nn.Module):
|
||||
# orientation for FusedMoE after _weight_iterator:
|
||||
# gate/up: [I, H] → w1/w3 expects [I, H]
|
||||
# down: [H, I] → w2 expects [H, I]
|
||||
assert loaded_weight.dim() == 2, (
|
||||
f"Expected 2D expert weight for {weight_name}, "
|
||||
f"got shape {loaded_weight.shape}"
|
||||
)
|
||||
# Scales and other quantization params may be 1D or scalar.
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name + ".weight",
|
||||
moe_name, # Pass mapped name (handles both weights and scales)
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
@@ -1499,6 +1514,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
".moe.down_proj",
|
||||
)
|
||||
|
||||
# Remap individual 2D expert weights:
|
||||
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
|
||||
# (This handles per-expert 2D quantized weights)
|
||||
name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)
|
||||
|
||||
# MoE expert weights: checkpoint stores as 3D packed
|
||||
# tensors. Explode into per-expert 2D weights for
|
||||
# FusedMoE weight_loader.
|
||||
|
||||
Reference in New Issue
Block a user