[Gemma4] Support quantized MoE (#39045)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
Dipika Sikka
2026-04-08 21:57:53 -04:00
committed by GitHub
parent eb4205fee5
commit 3aecdf08b4

View File

@@ -1248,21 +1248,27 @@ class Gemma4Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
# MoE expert weight mapping: checkpoint 3D packed tensors are # MoE expert weight mapping: checkpoint can have either:
# exploded in _weight_iterator to per-expert 2D weights like: # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
# 2. Already per-expert 2D weights (if quantized)
# Map to FusedMoE parameters:
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13) # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13) # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
# moe.experts.{id}.down_proj → FusedMoE w2 # moe.experts.{id}.down_proj → FusedMoE w2
# We build the mapping directly since Gemma4 uses bare param #
# names (no .weight suffix) unlike standard MoE checkpoints. # Use prefix matching to handle both weights and
# quantization scale parameters. The param_name is a prefix ending
# in underscore, and weight_name ends with a dot, so that:
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
num_experts = getattr(self.config, "num_experts", None) or 0 num_experts = getattr(self.config, "num_experts", None) or 0
expert_params_mapping = [ expert_params_mapping = [
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
( (
"experts.w13_weight" "experts.w13_"
if proj_name in ["gate_proj", "up_proj"] if proj_name in ["gate_proj", "up_proj"]
else "experts.w2_weight", else "experts.w2_",
f"experts.{expert_id}.{proj_name}", f"experts.{expert_id}.{proj_name}.",
expert_id, expert_id,
shard_id, shard_id,
) )
@@ -1322,9 +1328,21 @@ class Gemma4Model(nn.Module):
expert_id, expert_id,
shard_id, shard_id,
) in expert_params_mapping: ) in expert_params_mapping:
if weight_name not in name: # Match both:
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
# weight_name has trailing dot, so check with and without it
weight_name_base = weight_name.rstrip(".")
if weight_name in name:
# Has suffix (e.g., .weight_scale)
moe_name = name.replace(weight_name, param_name)
elif name.endswith(weight_name_base):
# Bare weight (no suffix)
moe_name = name.replace(
weight_name_base, param_name.rstrip("_") + "_weight"
)
else:
continue continue
moe_name = name.replace(weight_name, param_name)
if moe_name not in params_dict: if moe_name not in params_dict:
continue continue
if is_pp_missing_parameter(moe_name, self): if is_pp_missing_parameter(moe_name, self):
@@ -1334,15 +1352,12 @@ class Gemma4Model(nn.Module):
# orientation for FusedMoE after _weight_iterator: # orientation for FusedMoE after _weight_iterator:
# gate/up: [I, H] → w1/w3 expects [I, H] # gate/up: [I, H] → w1/w3 expects [I, H]
# down: [H, I] → w2 expects [H, I] # down: [H, I] → w2 expects [H, I]
assert loaded_weight.dim() == 2, ( # Scales and other quantization params may be 1D or scalar.
f"Expected 2D expert weight for {weight_name}, "
f"got shape {loaded_weight.shape}"
)
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader( weight_loader(
param, param,
loaded_weight, loaded_weight,
weight_name + ".weight", moe_name, # Pass mapped name (handles both weights and scales)
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
@@ -1499,6 +1514,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
".moe.down_proj", ".moe.down_proj",
) )
# Remap individual 2D expert weights:
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
# (This handles per-expert 2D quantized weights)
name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)
# MoE expert weights: checkpoint stores as 3D packed # MoE expert weights: checkpoint stores as 3D packed
# tensors. Explode into per-expert 2D weights for # tensors. Explode into per-expert 2D weights for
# FusedMoE weight_loader. # FusedMoE weight_loader.