From 3aecdf08b4a896a92e2cbd11c3d5a83d3c09abc1 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 8 Apr 2026 21:57:53 -0400 Subject: [PATCH] [Gemma4] Support quantized MoE (#39045) Signed-off-by: Dipika Sikka --- vllm/model_executor/models/gemma4.py | 48 ++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 2e9fc6819..c41c0de52 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -1248,21 +1248,27 @@ class Gemma4Model(nn.Module): ("gate_up_proj", "up_proj", 1), ] - # MoE expert weight mapping: checkpoint 3D packed tensors are - # exploded in _weight_iterator to per-expert 2D weights like: + # MoE expert weight mapping: checkpoint can have either: + # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D) + # 2. Already per-expert 2D weights (if quantized) + # Map to FusedMoE parameters: # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13) # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13) # moe.experts.{id}.down_proj → FusedMoE w2 - # We build the mapping directly since Gemma4 uses bare param - # names (no .weight suffix) unlike standard MoE checkpoints. + # + # Use prefix matching to handle both weights and + # quantization scale parameters. The param_name is a prefix ending + # in underscore, and weight_name ends with a dot, so that: + # "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale" + # "experts.0.gate_proj.weight" -> "experts.w13_weight" num_experts = getattr(self.config, "num_experts", None) or 0 expert_params_mapping = [ # (param_name, weight_name, expert_id, shard_id) ( - "experts.w13_weight" + "experts.w13_" if proj_name in ["gate_proj", "up_proj"] - else "experts.w2_weight", - f"experts.{expert_id}.{proj_name}", + else "experts.w2_", + f"experts.{expert_id}.{proj_name}.", expert_id, shard_id, ) @@ -1322,9 +1328,21 @@ class Gemma4Model(nn.Module): expert_id, shard_id, ) in expert_params_mapping: - if weight_name not in name: + # Match both: + # - Bare weights: "experts.0.down_proj" (from 3D explosion) + # - With suffix: "experts.0.down_proj.weight_scale" (2D quantized) + # weight_name has trailing dot, so check with and without it + weight_name_base = weight_name.rstrip(".") + if weight_name in name: + # Has suffix (e.g., .weight_scale) + moe_name = name.replace(weight_name, param_name) + elif name.endswith(weight_name_base): + # Bare weight (no suffix) + moe_name = name.replace( + weight_name_base, param_name.rstrip("_") + "_weight" + ) + else: continue - moe_name = name.replace(weight_name, param_name) if moe_name not in params_dict: continue if is_pp_missing_parameter(moe_name, self): @@ -1334,15 +1352,12 @@ class Gemma4Model(nn.Module): # orientation for FusedMoE after _weight_iterator: # gate/up: [I, H] → w1/w3 expects [I, H] # down: [H, I] → w2 expects [H, I] - assert loaded_weight.dim() == 2, ( - f"Expected 2D expert weight for {weight_name}, " - f"got shape {loaded_weight.shape}" - ) + # Scales and other quantization params may be 1D or scalar. weight_loader = param.weight_loader weight_loader( param, loaded_weight, - weight_name + ".weight", + moe_name, # Pass mapped name (handles both weights and scales) shard_id=shard_id, expert_id=expert_id, ) @@ -1499,6 +1514,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ".moe.down_proj", ) + # Remap individual 2D expert weights: + # .experts.{id}.{proj} → .moe.experts.{id}.{proj} + # (This handles per-expert 2D quantized weights) + name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name) + # MoE expert weights: checkpoint stores as 3D packed # tensors. Explode into per-expert 2D weights for # FusedMoE weight_loader.