Mark invariant normalizer in Gemma as non-persistent (#19788)
Signed-off-by: Yu-Hang Tang <Tang.Maxin@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e2148dc5ea
commit
83ca9ae47b
@@ -281,7 +281,9 @@ class GemmaModel(nn.Module):
|
||||
# data type such as bfloat16, not float32.
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
normalizer = self.config.hidden_size**0.5
|
||||
-self.register_buffer("normalizer", torch.tensor(normalizer))
|
||||
+self.register_buffer("normalizer",
+                     torch.tensor(normalizer),
+                     persistent=False)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
Reference in New Issue
Block a user