Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -29,12 +29,12 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
|
||||
return fake_logits
|
||||
|
||||
|
||||
def _create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
return torch.full((batch_size, ),
|
||||
fill_value=penalty_value,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
def _create_penalty_tensor(
|
||||
batch_size: int, penalty_value: float, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
return torch.full(
|
||||
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
|
||||
def _create_prompt_tokens_tensor(
|
||||
@@ -62,9 +62,9 @@ def _create_allowed_token_ids(
|
||||
if i % 2 == 1:
|
||||
continue
|
||||
if mask is None:
|
||||
mask = torch.zeros((batch_size, vocab_size),
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
mask = torch.zeros(
|
||||
(batch_size, vocab_size), dtype=torch.bool, device=device
|
||||
)
|
||||
start = min(i, vocab_size - 1)
|
||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||
mask[i, start:end] = True
|
||||
@@ -80,9 +80,9 @@ def _create_bad_words_token_ids(
|
||||
for batch_idx in range(batch_size):
|
||||
token_ids_single_batch = []
|
||||
for bad_words_length in bad_words_lengths:
|
||||
token_ids = np.random.choice(vocab_size,
|
||||
size=bad_words_length,
|
||||
replace=True).tolist()
|
||||
token_ids = np.random.choice(
|
||||
vocab_size, size=bad_words_length, replace=True
|
||||
).tolist()
|
||||
token_ids_single_batch.append(token_ids)
|
||||
bad_words_token_ids[batch_idx] = token_ids_single_batch
|
||||
if batch_size >= 2:
|
||||
@@ -95,26 +95,27 @@ def _create_bad_words_token_ids(
|
||||
# Returns all last tokens of bad word sequences that share the same prefix
|
||||
# as `given_prefix` (excluding the last token).
|
||||
def _collect_suffixes_with_same_prefix(
|
||||
given_prefix: list[int],
|
||||
bad_words_token_ids: list[list[int]]) -> list[int]:
|
||||
given_prefix: list[int], bad_words_token_ids: list[list[int]]
|
||||
) -> list[int]:
|
||||
return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
|
||||
|
||||
|
||||
# generate a valid token id that is not in bad_words_token_ids
|
||||
def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
|
||||
vocab_size: int) -> int:
|
||||
def _generate_valid_token_id(
|
||||
bad_words_token_ids: list[list[int]], vocab_size: int
|
||||
) -> int:
|
||||
forbidden_start_tokens = set()
|
||||
for bad_word in bad_words_token_ids:
|
||||
forbidden_start_tokens.add(bad_word[0])
|
||||
# Get a safe token that's not in forbidden starts
|
||||
safe_token_candidates = list(
|
||||
set(range(vocab_size)) - forbidden_start_tokens)
|
||||
safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens)
|
||||
# Pick a random safe token
|
||||
return np.random.choice(safe_token_candidates)
|
||||
|
||||
|
||||
def _update_output_token_ids_for_bad_words(
|
||||
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
|
||||
metadata: SamplingMetadata, vocab_size: int
|
||||
) -> dict[int, list[int]]:
|
||||
bad_words_last_tokens = {}
|
||||
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
|
||||
output_token_ids = metadata.output_token_ids[batch_idx]
|
||||
@@ -132,12 +133,13 @@ def _update_output_token_ids_for_bad_words(
|
||||
# Collect all last tokens from other bad words
|
||||
# that share this prefix
|
||||
bad_words_last_token.extend(
|
||||
_collect_suffixes_with_same_prefix(
|
||||
prefix, bad_words_token_ids))
|
||||
_collect_suffixes_with_same_prefix(prefix, bad_words_token_ids)
|
||||
)
|
||||
break # Maximum one update to output_token_ids
|
||||
else: # Make sure no accidental match to bad words
|
||||
output_token_ids[-1] = _generate_valid_token_id(
|
||||
bad_words_token_ids, vocab_size)
|
||||
bad_words_token_ids, vocab_size
|
||||
)
|
||||
bad_words_last_tokens[batch_idx] = bad_words_last_token
|
||||
return bad_words_last_tokens
|
||||
|
||||
@@ -152,22 +154,24 @@ def _create_default_sampling_metadata(
|
||||
prompt_token_ids: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
output_token_ids.append(
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
|
||||
)
|
||||
prompt_token_ids.append(
|
||||
np.random.randint(0,
|
||||
vocab_size,
|
||||
size=np.random.randint(
|
||||
1, MAX_NUM_PROMPT_TOKENS)).tolist())
|
||||
np.random.randint(
|
||||
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
|
||||
).tolist()
|
||||
)
|
||||
fake_sampling_metadata = SamplingMetadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
temperature=torch.full((batch_size,), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(
|
||||
prompt_token_ids, vocab_size, device
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
@@ -181,8 +185,8 @@ def _create_default_sampling_metadata(
|
||||
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
batch_size: int,
|
||||
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
|
||||
batch_size: int, vocab_size: int
|
||||
) -> tuple[list[list[int]], list[list[int]]]:
|
||||
"""
|
||||
Creates an output token list where each token occurs a distinct
|
||||
number of times.
|
||||
@@ -203,14 +207,13 @@ def _create_weighted_output_token_list(
|
||||
output_token_ids: list[list[int]] = []
|
||||
sorted_token_ids_in_output: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
distinct_token_ids = np.random.choice(vocab_size,
|
||||
size=np.random.randint(1, 10),
|
||||
replace=False).tolist()
|
||||
distinct_token_ids = np.random.choice(
|
||||
vocab_size, size=np.random.randint(1, 10), replace=False
|
||||
).tolist()
|
||||
sorted_token_ids_in_output.append(distinct_token_ids)
|
||||
output_token_ids_for_batch = []
|
||||
for index, token_id in enumerate(distinct_token_ids):
|
||||
output_token_ids_for_batch.extend(
|
||||
[token_id for _ in range(index + 1)])
|
||||
output_token_ids_for_batch.extend([token_id for _ in range(index + 1)])
|
||||
output_token_ids.append(output_token_ids_for_batch)
|
||||
return output_token_ids, sorted_token_ids_in_output
|
||||
|
||||
@@ -218,8 +221,9 @@ def _create_weighted_output_token_list(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
|
||||
def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
presence_penalty: float):
|
||||
def test_sampler_presence_penalty(
|
||||
device: str, batch_size: int, presence_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that if presence penalty is enabled then tokens
|
||||
are penalized as per their presence in the existing output.
|
||||
@@ -229,10 +233,12 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
sampling_metadata.presence_penalties = _create_penalty_tensor(
|
||||
batch_size, presence_penalty, torch.device(device))
|
||||
batch_size, presence_penalty, torch.device(device)
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
@@ -263,8 +269,9 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
|
||||
def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
frequency_penalty: float):
|
||||
def test_sampler_frequency_penalty(
|
||||
device: str, batch_size: int, frequency_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that if frequency penalty is enabled then tokens are
|
||||
penalized as per their frequency of occurrence.
|
||||
@@ -274,14 +281,15 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.frequency_penalties = _create_penalty_tensor(
|
||||
batch_size, frequency_penalty, torch.device(device))
|
||||
output_token_ids, sorted_token_ids_in_output = \
|
||||
_create_weighted_output_token_list(
|
||||
batch_size,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
batch_size, frequency_penalty, torch.device(device)
|
||||
)
|
||||
output_token_ids, sorted_token_ids_in_output = _create_weighted_output_token_list(
|
||||
batch_size,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
sampling_metadata.output_token_ids = output_token_ids
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
@@ -290,18 +298,17 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
|
||||
batch_idx]
|
||||
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[batch_idx]
|
||||
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
||||
len(distinct_sorted_token_ids_in_output) - 1]
|
||||
len(distinct_sorted_token_ids_in_output) - 1
|
||||
]
|
||||
if frequency_penalty > 0:
|
||||
# If `frequency_penalty` is set to > 0, it indicates
|
||||
# a preference for new tokens over existing ones. Verify that the
|
||||
# non-penalized token ID is not present in the output, while the
|
||||
# most penalized token is the one that occurs most frequently in
|
||||
# the output.
|
||||
assert (non_penalized_token_id
|
||||
not in distinct_sorted_token_ids_in_output)
|
||||
assert non_penalized_token_id not in distinct_sorted_token_ids_in_output
|
||||
assert penalized_token_id == most_frequent_token_id
|
||||
elif frequency_penalty < 0:
|
||||
# If `frequency_penalty` is set to < 0, it indicates
|
||||
@@ -316,8 +323,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
|
||||
def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
repetition_penalty: float):
|
||||
def test_sampler_repetition_penalty(
|
||||
device: str, batch_size: int, repetition_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
@@ -328,9 +336,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.repetition_penalties = _create_penalty_tensor(
|
||||
batch_size, repetition_penalty, torch.device(device))
|
||||
batch_size, repetition_penalty, torch.device(device)
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
@@ -338,32 +348,40 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
prompt_tokens = sampling_metadata.prompt_token_ids[
|
||||
batch_idx][:].tolist()
|
||||
prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx][:].tolist()
|
||||
output_tokens = sampling_metadata.output_token_ids[batch_idx]
|
||||
if repetition_penalty > 1.0:
|
||||
# If `repetition_penalty` > 1.0, verify that the non-penalized
|
||||
# token ID has not been seen before, while the penalized token ID
|
||||
# exists either in the prompt or the output.
|
||||
assert (non_penalized_token_id not in prompt_tokens
|
||||
and non_penalized_token_id not in output_tokens)
|
||||
assert (penalized_token_id in prompt_tokens
|
||||
or penalized_token_id in output_tokens)
|
||||
assert (
|
||||
non_penalized_token_id not in prompt_tokens
|
||||
and non_penalized_token_id not in output_tokens
|
||||
)
|
||||
assert (
|
||||
penalized_token_id in prompt_tokens
|
||||
or penalized_token_id in output_tokens
|
||||
)
|
||||
elif repetition_penalty < 1.0:
|
||||
# If `repetition_penalty` < 1.0, verify that the penalized
|
||||
# token ID has not been seen before, while the non-penalized
|
||||
# token ID exists either in the prompt or the output.
|
||||
assert (penalized_token_id not in prompt_tokens
|
||||
and penalized_token_id not in output_tokens)
|
||||
assert (non_penalized_token_id in prompt_tokens
|
||||
or non_penalized_token_id in output_tokens)
|
||||
assert (
|
||||
penalized_token_id not in prompt_tokens
|
||||
and penalized_token_id not in output_tokens
|
||||
)
|
||||
assert (
|
||||
non_penalized_token_id in prompt_tokens
|
||||
or non_penalized_token_id in output_tokens
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
|
||||
def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
num_allowed_token_ids: int):
|
||||
def test_sampler_allowed_token_ids(
|
||||
device: str, batch_size: int, num_allowed_token_ids: int
|
||||
):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
@@ -374,7 +392,8 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
mask = _create_allowed_token_ids(
|
||||
batch_size=batch_size,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
@@ -394,17 +413,19 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
start = min(batch_idx, VOCAB_SIZE - 1)
|
||||
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
|
||||
if token_id >= start and token_id < end:
|
||||
assert logits_for_req[token_id] == -float(
|
||||
"inf"), f"{batch_idx}, {token_id}"
|
||||
assert logits_for_req[token_id] == -float("inf"), (
|
||||
f"{batch_idx}, {token_id}"
|
||||
)
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(device: str, batch_size: int,
|
||||
bad_words_lengths: tuple[int, ...]):
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(
|
||||
device: str, batch_size: int, bad_words_lengths: tuple[int, ...]
|
||||
):
|
||||
"""
|
||||
Test to verify that when the bad words restriction is present, tokens
|
||||
are penalized based on their match with the bad words.
|
||||
@@ -414,19 +435,24 @@ def test_sampler_bad_words(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
|
||||
batch_size, VOCAB_SIZE, bad_words_lengths)
|
||||
batch_size, VOCAB_SIZE, bad_words_lengths
|
||||
)
|
||||
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
|
||||
sampling_metadata, VOCAB_SIZE)
|
||||
sampling_metadata, VOCAB_SIZE
|
||||
)
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if (batch_idx in bad_words_last_tokens
|
||||
and token_id in bad_words_last_tokens[batch_idx]):
|
||||
if (
|
||||
batch_idx in bad_words_last_tokens
|
||||
and token_id in bad_words_last_tokens[batch_idx]
|
||||
):
|
||||
assert logits_for_req[token_id] == -float("inf")
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
Reference in New Issue
Block a user