[Speculative Decoding] Add norm_before_fc for gpt-oss draft models (#36545)
Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
@@ -150,6 +150,7 @@ class LlamaModel(nn.Module):
|
|||||||
self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
|
self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
|
||||||
else:
|
else:
|
||||||
self.use_aux_hidden_state = True
|
self.use_aux_hidden_state = True
|
||||||
|
self.norm_before_fc = getattr(self.config, "norm_before_fc", False)
|
||||||
|
|
||||||
current_vllm_config = get_current_vllm_config()
|
current_vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
@@ -175,6 +176,13 @@ class LlamaModel(nn.Module):
|
|||||||
fc_input_size = self.config.target_hidden_size * 3
|
fc_input_size = self.config.target_hidden_size * 3
|
||||||
else:
|
else:
|
||||||
fc_input_size = self.config.hidden_size * 3
|
fc_input_size = self.config.hidden_size * 3
|
||||||
|
if self.norm_before_fc:
|
||||||
|
self.input_norm = RMSNorm(
|
||||||
|
fc_input_size,
|
||||||
|
eps=self.config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.input_norm = None
|
||||||
self.fc = ReplicatedLinear(
|
self.fc = ReplicatedLinear(
|
||||||
input_size=fc_input_size,
|
input_size=fc_input_size,
|
||||||
output_size=self.config.hidden_size,
|
output_size=self.config.hidden_size,
|
||||||
@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
if not self.model.use_aux_hidden_state:
|
if not self.model.use_aux_hidden_state:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
# combine multiple auxiliary hidden states returned by eagle3
|
# combine multiple auxiliary hidden states returned by eagle3
|
||||||
|
|
||||||
|
if self.model.norm_before_fc:
|
||||||
|
hidden_states = self.model.input_norm(hidden_states)
|
||||||
return self.model.fc(hidden_states)
|
return self.model.fc(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
skip_substrs.append("embed_tokens")
|
skip_substrs.append("embed_tokens")
|
||||||
if not self.model.use_aux_hidden_state:
|
if not self.model.use_aux_hidden_state:
|
||||||
skip_substrs.append("fc.")
|
skip_substrs.append("fc.")
|
||||||
|
if not self.model.norm_before_fc:
|
||||||
|
skip_substrs.append("input_norm.")
|
||||||
loader = AutoWeightsLoader(
|
loader = AutoWeightsLoader(
|
||||||
self,
|
self,
|
||||||
skip_prefixes=None,
|
skip_prefixes=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user