[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
@@ -57,6 +57,26 @@ def _create_logit_bias(
|
||||
return res
|
||||
|
||||
|
||||
def _create_allowed_token_ids(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
num_allowed_token_ids: int,
|
||||
device: torch.device,
|
||||
) -> Optional[torch.Tensor]:
|
||||
mask: Optional[torch.Tensor] = None
|
||||
for i in range(batch_size):
|
||||
if i % 2 == 1:
|
||||
continue
|
||||
if mask is None:
|
||||
mask = torch.zeros((batch_size, vocab_size),
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
start = min(i, vocab_size - 1)
|
||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||
mask[i, start:end] = True
|
||||
return mask
|
||||
|
||||
|
||||
def _create_default_sampling_metadata(
|
||||
num_output_tokens: int,
|
||||
batch_size: int,
|
||||
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
|
||||
no_penalties=True,
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
allowed_token_ids_mask=None,
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
|
||||
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
sampling_metadata.frequency_penalties = _create_penalty_tensor(
|
||||
batch_size, frequency_penalty, torch.device(device))
|
||||
output_token_ids, sorted_token_ids_in_output = \
|
||||
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
|
||||
_create_weighted_output_token_list(
|
||||
batch_size,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
sampling_metadata.output_token_ids = output_token_ids
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
distinct_sorted_token_ids_in_output = \
|
||||
sorted_token_ids_in_output[batch_idx]
|
||||
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
|
||||
batch_idx]
|
||||
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
||||
len(distinct_sorted_token_ids_in_output) - 1]
|
||||
if frequency_penalty > 0:
|
||||
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
# non-penalized token ID is not present in the output, while the
|
||||
# most penalized token is the one that occurs most frequently in
|
||||
# the output.
|
||||
assert non_penalized_token_id \
|
||||
not in distinct_sorted_token_ids_in_output
|
||||
assert (non_penalized_token_id
|
||||
not in distinct_sorted_token_ids_in_output)
|
||||
assert penalized_token_id == most_frequent_token_id
|
||||
elif frequency_penalty < 0:
|
||||
# If `frequency_penalty` is set to < 0, it indicates
|
||||
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
# in the output, while the penalized token ID is one that has not
|
||||
# yet appeared.
|
||||
assert non_penalized_token_id == most_frequent_token_id
|
||||
assert penalized_token_id \
|
||||
not in distinct_sorted_token_ids_in_output
|
||||
assert penalized_token_id not in distinct_sorted_token_ids_in_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
# If `repetition_penalty` > 1.0, verify that the non-penalized
|
||||
# token ID has not been seen before, while the penalized token ID
|
||||
# exists either in the prompt or the output.
|
||||
assert (non_penalized_token_id not in prompt_tokens and \
|
||||
non_penalized_token_id not in output_tokens)
|
||||
assert (penalized_token_id in prompt_tokens or \
|
||||
penalized_token_id in output_tokens)
|
||||
assert (non_penalized_token_id not in prompt_tokens
|
||||
and non_penalized_token_id not in output_tokens)
|
||||
assert (penalized_token_id in prompt_tokens
|
||||
or penalized_token_id in output_tokens)
|
||||
elif repetition_penalty < 1.0:
|
||||
# If `repetition_penalty` < 1.0, verify that the penalized
|
||||
# token ID has not been seen before, while the non-penalized
|
||||
# token ID exists either in the prompt or the output.
|
||||
assert (penalized_token_id not in prompt_tokens and \
|
||||
penalized_token_id not in output_tokens)
|
||||
assert (non_penalized_token_id in prompt_tokens or \
|
||||
non_penalized_token_id in output_tokens)
|
||||
assert (penalized_token_id not in prompt_tokens
|
||||
and penalized_token_id not in output_tokens)
|
||||
assert (non_penalized_token_id in prompt_tokens
|
||||
or non_penalized_token_id in output_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
|
||||
1e-2)
|
||||
else:
|
||||
assert logits_for_req[token_id] == pytest.approx(1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
|
||||
def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
num_allowed_token_ids: int):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
output.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
# Create fake logits where each token is assigned the same
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
mask = _create_allowed_token_ids(
|
||||
batch_size=batch_size,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
num_allowed_token_ids=num_allowed_token_ids,
|
||||
device=device,
|
||||
)
|
||||
sampling_metadata.allowed_token_ids_mask = mask
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
if batch_idx % 2 == 1:
|
||||
assert torch.all(logits_for_req != -float("inf"))
|
||||
continue
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
start = min(batch_idx, VOCAB_SIZE - 1)
|
||||
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
|
||||
if token_id >= start and token_id < end:
|
||||
assert logits_for_req[token_id] == -float(
|
||||
"inf"), f"{batch_idx}, {token_id}"
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
Reference in New Issue
Block a user