[Quantization] Support FP8 MoE bias for models like GPT-OSS (#34906)
Signed-off-by: jasperjiaguo <jasperg662@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
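For context: a GPT-OSS-style MoE expert adds a per-expert bias after each projection, while the FP8 path quantizes only the weights. A minimal plain-PyTorch sketch of what one biased expert computes, with toy sizes and a SwiGLU-style activation assumed (the actual fused kernels in vLLM differ):

    import torch
    import torch.nn.functional as F

    # Toy sizes (assumptions, not real model dims)
    E, H, I = 4, 64, 128        # experts, hidden_size, intermediate per partition
    x = torch.randn(H)

    w13 = torch.randn(E, 2 * I, H)   # fused gate/up projection per expert
    b13 = torch.zeros(E, 2 * I)      # per-expert bias, kept in orig_dtype
    w2 = torch.randn(E, H, I)        # down projection per expert
    b2 = torch.zeros(E, H)

    e = 1                            # expert picked by the router
    gate, up = (w13[e] @ x + b13[e]).chunk(2)   # bias added after the matmul
    y = w2[e] @ (F.silu(gate) * up) + b2[e]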
@@ -758,6 +758,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        # BIASES (for models like GPT-OSS that have biased MoE)
+        if self.moe.has_bias:
+            w13_bias = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=layer.orig_dtype,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_bias", w13_bias)
+            set_weight_attrs(w13_bias, extra_weight_attrs)
+            w2_bias = torch.nn.Parameter(
+                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_bias", w2_bias)
+            set_weight_attrs(w2_bias, extra_weight_attrs)
+
         # WEIGHT_SCALES
         if not self.block_quant:
             # For per-tensor quant, the scales are per expert and weight.
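The new block follows vLLM's usual parameter-registration idiom: allocate zeros with requires_grad=False, register on the layer, then attach loader metadata via set_weight_attrs. A self-contained sketch of that idiom, with a stand-in for set_weight_attrs (the real helper lives in vllm.model_executor.utils):

    import torch

    def set_weight_attrs(weight, attrs):
        # Stand-in for vLLM's helper: it setattr()s loader
        # metadata onto the parameter object.
        for key, value in attrs.items():
            setattr(weight, key, value)

    layer = torch.nn.Module()
    num_experts, intermediate_size_per_partition = 4, 128

    w13_bias = torch.nn.Parameter(
        torch.zeros(num_experts, 2 * intermediate_size_per_partition,
                    dtype=torch.bfloat16),   # orig_dtype, not FP8
        requires_grad=False,
    )
    layer.register_parameter("w13_bias", w13_bias)
    set_weight_attrs(w13_bias, {"weight_loader": lambda p, w: p.data.copy_(w)})

The 2 * intermediate_size_per_partition dimension matches the fused w1/w3 layout, and the bias stays in layer.orig_dtype, presumably because it is added to the activations after the FP8 matmul rather than quantized with the weights.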
@@ -939,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         a1_scale = layer.w13_input_scale
         a2_scale = layer.w2_input_scale
 
-        return make_fp8_moe_quant_config(
+        quant_config = make_fp8_moe_quant_config(
             fp8_backend=self.fp8_backend,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
@@ -948,6 +967,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             block_shape=self.weight_block_size,
         )
 
+        # Inject biases into the quant config if the model has them
+        # (e.g. GPT-OSS biased MoE)
+        if quant_config is not None and self.moe.has_bias:
+            w13_bias = getattr(layer, "w13_bias", None)
+            w2_bias = getattr(layer, "w2_bias", None)
+            if w13_bias is not None:
+                quant_config._w1.bias = w13_bias
+            if w2_bias is not None:
+                quant_config._w2.bias = w2_bias
+
+        return quant_config
+
     @property
     def supports_eplb(self) -> bool:
         return True
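The previous hunk captures the config in a local instead of returning it directly, so this hunk can attach the biases before returning. Since make_fp8_moe_quant_config takes no bias arguments here, the biases are injected by mutating the config's private per-weight entries. A minimal sketch of that pattern using hypothetical stand-in dataclasses (the real quant-config internals behind _w1/_w2 may differ):

    from dataclasses import dataclass, field
    from typing import Optional
    import torch

    @dataclass
    class _WeightEntry:
        # Hypothetical stand-in for the config's per-weight record (_w1/_w2)
        scale: Optional[torch.Tensor] = None
        bias: Optional[torch.Tensor] = None

    @dataclass
    class QuantConfig:
        _w1: _WeightEntry = field(default_factory=_WeightEntry)
        _w2: _WeightEntry = field(default_factory=_WeightEntry)

    quant_config = QuantConfig()
    w13_bias = torch.zeros(4, 256)   # present on the layer
    w2_bias = None                   # e.g. checkpoint had no down-proj bias

    # Same guarded injection as the hunk above: only set what exists
    if w13_bias is not None:
        quant_config._w1.bias = w13_bias
    if w2_bias is not None:
        quant_config._w2.bias = w2_bias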
@@ -1168,6 +1199,28 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
         # stash the correct device for `patched_weight_loader`
         layer._load_device = torch.get_default_device()
 
+        # BIASES (for models like GPT-OSS that have biased MoE)
+        if self.moe.has_bias:
+            # Use the original weight_loader (not patched) for biases
+            orig_extra_weight_attrs = dict(extra_weight_attrs)
+            orig_extra_weight_attrs["weight_loader"] = weight_loader
+            w13_bias = torch.nn.Parameter(
+                torch.zeros(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=layer.orig_dtype,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_bias", w13_bias)
+            set_weight_attrs(w13_bias, orig_extra_weight_attrs)
+            w2_bias = torch.nn.Parameter(
+                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_bias", w2_bias)
+            set_weight_attrs(w2_bias, orig_extra_weight_attrs)
+
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         # They will be combined to a single scale after weight loading.
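In the online-quantization subclass, the weight_loader in extra_weight_attrs has been patched to cast weights to FP8 as they load; biases must keep the original dtype, so they get a copied attrs dict pointing back at the unpatched loader. A toy sketch of that copy-and-override pattern (loader signatures simplified to two arguments):

    import torch

    def weight_loader(param, loaded):
        # Original loader: plain copy, dtype preserved
        param.data.copy_(loaded)

    def patched_weight_loader(param, loaded):
        # Online-quant loader (sketch): casts to FP8 on load
        param.data.copy_(loaded.to(torch.float8_e4m3fn))

    extra_weight_attrs = {"weight_loader": patched_weight_loader}

    # Shallow copy, then restore the unpatched loader for biases only
    orig_extra_weight_attrs = dict(extra_weight_attrs)
    orig_extra_weight_attrs["weight_loader"] = weight_loader

    bias = torch.nn.Parameter(torch.zeros(2, 4, dtype=torch.bfloat16),
                              requires_grad=False)
    orig_extra_weight_attrs["weight_loader"](bias, torch.ones(2, 4))
    assert bias.dtype == torch.bfloat16   # bias was not cast to FP8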