[Speculative Decoding] Add norm_before_fc for gpt-oss draft models (#36545)
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
@@ -150,6 +150,7 @@ class LlamaModel(nn.Module):
             self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
         else:
             self.use_aux_hidden_state = True
+        self.norm_before_fc = getattr(self.config, "norm_before_fc", False)

         current_vllm_config = get_current_vllm_config()
@@ -175,6 +176,13 @@ class LlamaModel(nn.Module):
             fc_input_size = self.config.target_hidden_size * 3
         else:
             fc_input_size = self.config.hidden_size * 3
+        if self.norm_before_fc:
+            self.input_norm = RMSNorm(
+                fc_input_size,
+                eps=self.config.rms_norm_eps,
+            )
+        else:
+            self.input_norm = None
         self.fc = ReplicatedLinear(
             input_size=fc_input_size,
             output_size=self.config.hidden_size,
@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         if not self.model.use_aux_hidden_state:
            return hidden_states
        # combine multiple auxiliary hidden states returned by eagle3
+
+        if self.model.norm_before_fc:
+            hidden_states = self.model.input_norm(hidden_states)
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             skip_substrs.append("embed_tokens")
         if not self.model.use_aux_hidden_state:
             skip_substrs.append("fc.")
+        if not self.model.norm_before_fc:
+            skip_substrs.append("input_norm.")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
Reference in New Issue
Block a user