[Speculative Decoding] Add norm_before_fc for gpt-oss draft models (#36545)

Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Shubhra Pandit
2026-03-12 19:03:32 -04:00
committed by GitHub
parent a79c1c2c80
commit 87985077a4

View File

@@ -150,6 +150,7 @@ class LlamaModel(nn.Module):
self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
else:
self.use_aux_hidden_state = True
self.norm_before_fc = getattr(self.config, "norm_before_fc", False)
current_vllm_config = get_current_vllm_config()
@@ -175,6 +176,13 @@ class LlamaModel(nn.Module):
fc_input_size = self.config.target_hidden_size * 3
else:
fc_input_size = self.config.hidden_size * 3
if self.norm_before_fc:
self.input_norm = RMSNorm(
fc_input_size,
eps=self.config.rms_norm_eps,
)
else:
self.input_norm = None
self.fc = ReplicatedLinear(
input_size=fc_input_size,
output_size=self.config.hidden_size,
@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
if not self.model.use_aux_hidden_state:
return hidden_states
# combine multiple auxiliary hidden states returned by eagle3
if self.model.norm_before_fc:
hidden_states = self.model.input_norm(hidden_states)
return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs.append("embed_tokens")
if not self.model.use_aux_hidden_state:
skip_substrs.append("fc.")
if not self.model.norm_before_fc:
skip_substrs.append("input_norm.")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,