[mypy] Fix wrong type annotations related to tuple (#25660)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -72,8 +72,10 @@ def _create_allowed_token_ids(
|
||||
|
||||
|
||||
def _create_bad_words_token_ids(
|
||||
batch_size: int, vocab_size: int,
|
||||
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
bad_words_lengths: tuple[int, ...],
|
||||
) -> dict[int, list[list[int]]]:
|
||||
bad_words_token_ids = {}
|
||||
for batch_idx in range(batch_size):
|
||||
token_ids_single_batch = []
|
||||
@@ -402,7 +404,7 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(device: str, batch_size: int,
|
||||
bad_words_lengths: list[tuple[int]]):
|
||||
bad_words_lengths: tuple[int, ...]):
|
||||
"""
|
||||
Test to verify that when the bad words restriction is present, tokens
|
||||
are penalized based on their match with the bad words.
|
||||
|
||||
Reference in New Issue
Block a user