[mypy] Fix wrong type annotations related to tuple (#25660)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-25 21:00:45 +08:00
committed by GitHub
parent 1e9a77e037
commit 2f17117606
9 changed files with 25 additions and 20 deletions

View File

@@ -72,8 +72,10 @@ def _create_allowed_token_ids(
def _create_bad_words_token_ids(
batch_size: int, vocab_size: int,
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
batch_size: int,
vocab_size: int,
bad_words_lengths: tuple[int, ...],
) -> dict[int, list[list[int]]]:
bad_words_token_ids = {}
for batch_idx in range(batch_size):
token_ids_single_batch = []
@@ -402,7 +404,7 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
bad_words_lengths: list[tuple[int]]):
bad_words_lengths: tuple[int, ...]):
"""
Test to verify that when the bad words restriction is present, tokens
are penalized based on their match with the bad words.