[Quantization] Support FP8 MoE bias for models like GPT-OSS (#34906)

Signed-off-by: jasperjiaguo <jasperg662@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Jia Guo
2026-02-23 19:07:47 -08:00
committed by GitHub
parent 2ff4e51152
commit ec85340531

View File

@@ -758,6 +758,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# BIASES (for models like GPT-OSS that have biased MoE)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
# WEIGHT_SCALES
if not self.block_quant:
# For per-tensor quant, the scales are per expert and weight.
@@ -939,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
return make_fp8_moe_quant_config(
quant_config = make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -948,6 +967,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_shape=self.weight_block_size,
)
# Inject biases into the quant config if the model has them
# (e.g. GPT-OSS biased MoE)
if quant_config is not None and self.moe.has_bias:
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
if w13_bias is not None:
quant_config._w1.bias = w13_bias
if w2_bias is not None:
quant_config._w2.bias = w2_bias
return quant_config
@property
def supports_eplb(self) -> bool:
    """Whether this quantization method supports EPLB.

    NOTE(review): presumably "EPLB" is expert-parallel load balancing —
    confirm against the base class / call sites, which are outside this view.
    """
    return True
@@ -1168,6 +1199,28 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
# BIASES (for models like GPT-OSS that have biased MoE)
if self.moe.has_bias:
# Use the original weight_loader (not patched) for biases
orig_extra_weight_attrs = dict(extra_weight_attrs)
orig_extra_weight_attrs["weight_loader"] = weight_loader
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, orig_extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.