Update pre-commit hooks (#12475)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-01-28 00:23:08 +00:00
committed by GitHub
parent 6116ca8cd7
commit 823ab79633
64 changed files with 322 additions and 288 deletions

View File

@@ -115,17 +115,17 @@ class VocabParallelEmbeddingShardIndices:
def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index <=
self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)
assert (self.padded_org_vocab_start_index
<= self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index
<= self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <=
self.padded_added_vocab_start_index)
assert (self.added_vocab_start_index
<= self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
@@ -141,8 +141,8 @@ def get_masked_input_and_mask(
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (