Fix import that was moved in Transformers 5.2.0 (#36120)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -516,8 +516,11 @@ class Base(
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.check_version("5.0.0", "Eagle3 support")
|
||||
from transformers.utils.generic import OutputRecorder
|
||||
self.check_version("5.2.0", "Eagle3 support")
|
||||
from transformers.utils.output_capturing import (
|
||||
OutputRecorder,
|
||||
maybe_install_capturing_hooks,
|
||||
)
|
||||
|
||||
# The default value in PreTrainedModel is None
|
||||
if self.model._can_record_outputs is None:
|
||||
@@ -532,6 +535,9 @@ class Base(
|
||||
self.model._can_record_outputs[layer_key] = aux_hidden_state_i
|
||||
self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
|
||||
|
||||
# Ensure that the capture hooks are installed before dynamo traces the model
|
||||
maybe_install_capturing_hooks(self.model)
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default auxiliary hidden-state layer indices for Eagle3.

    Selects one early, one middle, and one late decoder layer, derived
    from the text model's total hidden-layer count
    (``self.text_config.num_hidden_layers``).
    """
    total = self.text_config.num_hidden_layers
    early, middle, late = 2, total // 2, total - 3
    return (early, middle, late)
|
||||
|
||||
Reference in New Issue
Block a user