diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 5f66716d5..462d18c98 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -150,6 +150,7 @@ class LlamaModel(nn.Module):
             self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
         else:
             self.use_aux_hidden_state = True
+        self.norm_before_fc = getattr(self.config, "norm_before_fc", False)
 
         current_vllm_config = get_current_vllm_config()
 
@@ -175,6 +176,13 @@ class LlamaModel(nn.Module):
             fc_input_size = self.config.target_hidden_size * 3
         else:
             fc_input_size = self.config.hidden_size * 3
+        if self.norm_before_fc:
+            self.input_norm = RMSNorm(
+                fc_input_size,
+                eps=self.config.rms_norm_eps,
+            )
+        else:
+            self.input_norm = None
         self.fc = ReplicatedLinear(
             input_size=fc_input_size,
             output_size=self.config.hidden_size,
@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         if not self.model.use_aux_hidden_state:
             return hidden_states
         # combine multiple auxiliary hidden states returned by eagle3
+
+        if self.model.norm_before_fc:
+            hidden_states = self.model.input_norm(hidden_states)
         return self.model.fc(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         skip_substrs.append("embed_tokens")
         if not self.model.use_aux_hidden_state:
             skip_substrs.append("fc.")
+        if not self.model.norm_before_fc:
+            skip_substrs.append("input_norm.")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,