[Model] Add NVFP4 quantization support for Step3.5-Flash (#34478)

Signed-off-by: tacos8me <ian@cloudhabit.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Authored by tacos8me on 2026-02-22 14:30:46 -05:00; committed by GitHub
parent 682566b18e
commit b7892a3bef
5 changed files with 204 additions and 4 deletions

@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
from collections.abc import Iterable
import typing
from collections.abc import Callable, Iterable
from typing import Any
import torch
@@ -231,6 +232,7 @@ class Step3p5Attention(nn.Module):
hidden_size,
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_proj",
)
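
The only functional change in this hunk is threading quant_config into g_proj; without it, vLLM's linear layers construct an unquantized method and the projection is skipped by the NVFP4 scheme. A minimal sketch of the pattern, using ReplicatedLinear as a stand-in (the real g_proj class and shapes live in the Step3.5 model file):

from vllm.model_executor.layers.linear import ReplicatedLinear

def build_g_proj(hidden_size: int, total_num_heads: int, quant_config, prefix: str):
    # Sketch only: the layer class here is assumed, not the actual g_proj.
    # With quant_config set, the layer resolves its quantize method from the
    # config (here compressed-tensors NVFP4) at construction time; with
    # quant_config=None it stays on the unquantized path.
    return ReplicatedLinear(
        hidden_size,
        total_num_heads,
        bias=False,
        quant_config=quant_config,  # the line this hunk adds
        prefix=f"{prefix}.g_proj",  # lets per-layer quant schemes match this module
    )
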
@@ -640,12 +642,22 @@ class Step3p5Model(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
# Old packed 3D format: .moe.gate_proj.weight [num_experts, out, in]
expert_params_mapping = [
(".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
(".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
(".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
]
# New per-expert format: .moe.experts.E.gate_proj.weight_packed [out, in]
per_expert_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.moe_num_experts,
)
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
for name, loaded_weight in weights:
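
The hunk above keeps both checkpoint layouts loadable. For orientation, a sketch of the tuples each mapping produces; the strings are illustrative of the shapes FusedMoE.make_expert_params_mapping generates, not copied from a real checkpoint:

# Old packed 3D layout: one stacked tensor per projection, e.g.
# ".moe.gate_proj.weight" with shape [num_experts, out, in], matched by
# 3-tuples of (param_name, ckpt_name, shard_id):
old_style = (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1")

# New per-expert layout (LLM Compressor): one 2D tensor per expert, matched
# by 4-tuples of (param_name, weight_name, expert_id, shard_id), roughly:
per_expert_style = [
    ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w2_", "experts.0.down_proj.", 0, "w2"),
    ("experts.w13_", "experts.0.up_proj.", 0, "w3"),
    # ... repeated for every expert id up to num_experts - 1
]
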
@@ -668,6 +680,54 @@ class Step3p5Model(nn.Module):
if layer_idx >= config.num_hidden_layers:
continue
# Per-expert MoE weights (new format from LLM Compressor):
# .moe.experts.{E}.{gate,up,down}_proj.{weight_packed,scale,...}
# Each tensor is stored individually per expert, not stacked 3D.
if ".moe.experts." in local_name:
is_expert_weight = False
for mapping in per_expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in local_name:
continue
is_expert_weight = True
name_mapped = local_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
if name_mapped not in params_dict:
continue
param = params_dict[name_mapped]
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
loaded_params.add(name_mapped)
break
else:
if (
not is_expert_weight
and not is_pp_missing_parameter(local_name, self)
and local_name in params_dict
):
# Not an expert projection: fall back to the default loader
# (e.g. share_expert weights that matched ".moe.experts.").
param = params_dict[local_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
weight_loader(param, loaded_weight)
loaded_params.add(local_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in local_name:
continue
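
Tracing one key through the per-expert branch above makes the renaming concrete; the key and indices below are hypothetical, chosen only to match the mapping sketch earlier:

# Hypothetical checkpoint key (illustrative, not from a real file):
name = "model.layers.3.moe.experts.17.gate_proj.weight_packed"

# The matching per_expert_mapping entry:
param_name, weight_name, expert_id, shard_id = (
    "experts.w13_", "experts.17.gate_proj.", 17, "w1",
)

# The branch rewrites the checkpoint key to the fused parameter name...
name_mapped = name.replace(weight_name, param_name)
assert name_mapped == "model.layers.3.moe.experts.w13_weight_packed"

# ...then calls FusedMoE's weight_loader with shard_id="w1" and
# expert_id=17, which copies the 2D tensor into expert 17's slice of the
# stacked w13 parameter and reports success via return_success=True.
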
@@ -703,6 +763,16 @@ class Step3p5Model(nn.Module):
param = params_dict[replaced_name]
weight_loader = param.weight_loader
moe_expert_num = self.moe_num_experts
# Per-tensor global scales (e.g. weight_global_scale)
# have shape [1] in compressed-tensors NVFP4 checkpoints.
# Expand to per-expert shape before the expert loop below.
if (
loaded_weight.shape[0] == 1
and loaded_weight.shape[0] != moe_expert_num
):
loaded_weight = loaded_weight.expand(
moe_expert_num, *loaded_weight.shape[1:]
)
assert loaded_weight.shape[0] == moe_expert_num
for expert_id in range(moe_expert_num):
loaded_weight_expert = loaded_weight[expert_id]
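
The expand above is a broadcast view, not a copy: a [1]-shaped per-tensor scale is repeated once per expert so the loop can index it uniformly. A standalone illustration with made-up values:

import torch

moe_expert_num = 4
# A per-tensor global scale as stored in a compressed-tensors NVFP4
# checkpoint: one value shared by all experts, shape [1].
loaded_weight = torch.tensor([0.0123])

if loaded_weight.shape[0] == 1 and loaded_weight.shape[0] != moe_expert_num:
    loaded_weight = loaded_weight.expand(moe_expert_num, *loaded_weight.shape[1:])

assert loaded_weight.shape[0] == moe_expert_num
print(loaded_weight)  # tensor([0.0123, 0.0123, 0.0123, 0.0123])
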