[v1] Support allowed_token_ids in v1 Sampler (#13210)

Signed-off-by: Lu Fang <lufang@fb.com>
Lu Fang
2025-02-21 22:13:05 -08:00
committed by GitHub
parent 8aca27fa11
commit bb78fb318e
7 changed files with 168 additions and 19 deletions


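The tests below exercise how the v1 Sampler applies an allowed-token-ids mask to the logits: positions marked True in the mask are forced to -inf before sampling, so those tokens can never be chosen. A minimal sketch of that operation, assuming a boolean mask of shape (batch_size, vocab_size) where True marks a disallowed token (the convention the new test asserts); the function name here is a stand-in for the sampler method under test, not the exact implementation:

import torch

def apply_allowed_token_ids_sketch(logits: torch.Tensor,
                                   mask: torch.Tensor) -> torch.Tensor:
    # True entries mark tokens the request is not allowed to produce;
    # filling them with -inf guarantees they are never sampled.
    return logits.masked_fill(mask, float("-inf"))
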
@@ -57,6 +57,26 @@ def _create_logit_bias(
     return res
 
 
+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
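To make the helper's layout concrete, here is a small worked example (hypothetical sizes) of the mask it builds: only even batch rows are restricted, and row i marks columns [i, i + num_allowed_token_ids) as True, i.e. as positions that will be forced to -inf.

import torch

# Assuming the _create_allowed_token_ids helper added above:
mask = _create_allowed_token_ids(batch_size=4, vocab_size=8,
                                 num_allowed_token_ids=2,
                                 device=torch.device("cpu"))
# row 0: True at columns [0, 2); row 2: True at columns [2, 4);
# rows 1 and 3 stay all-False, i.e. unrestricted.
print(mask.int() if mask is not None else mask)
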
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.frequency_penalties = _create_penalty_tensor(
         batch_size, frequency_penalty, torch.device(device))
     output_token_ids, sorted_token_ids_in_output = \
-        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     for batch_idx in range(batch_size):
         non_penalized_token_id = logits[batch_idx].argmax().item()
         penalized_token_id = logits[batch_idx].argmin().item()
-        distinct_sorted_token_ids_in_output = \
-            sorted_token_ids_in_output[batch_idx]
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
             len(distinct_sorted_token_ids_in_output) - 1]
         if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # non-penalized token ID is not present in the output, while the
             # most penalized token is the one that occurs most frequently in
             # the output.
-            assert non_penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
             assert penalized_token_id == most_frequent_token_id
         elif frequency_penalty < 0:
             # If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # in the output, while the penalized token ID is one that has not
             # yet appeared.
             assert non_penalized_token_id == most_frequent_token_id
-            assert penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
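The frequency-penalty assertions above match the OpenAI-style formula, in which each token's logit is reduced by the penalty scaled by how often the token already appears in the output. A sketch under that assumption (`output_counts` is a hypothetical per-token occurrence-count tensor, not a name from the diff):

import torch

def apply_frequency_penalty_sketch(logits: torch.Tensor,
                                   output_counts: torch.Tensor,
                                   frequency_penalty: float) -> torch.Tensor:
    # penalty > 0 pushes frequently generated tokens down (the argmin
    # checks above); penalty < 0 boosts them instead (the argmax checks).
    return logits - frequency_penalty * output_counts
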
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             # If `repetition_penalty` > 1.0, verify that the non-penalized
             # token ID has not been seen before, while the penalized token ID
             # exists either in the prompt or the output.
-            assert (non_penalized_token_id not in prompt_tokens and \
-                    non_penalized_token_id not in output_tokens)
-            assert (penalized_token_id in prompt_tokens or \
-                    penalized_token_id in output_tokens)
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
         elif repetition_penalty < 1.0:
             # If `repetition_penalty` < 1.0, verify that the penalized
             # token ID has not been seen before, while the non-penalized
             # token ID exists either in the prompt or the output.
-            assert (penalized_token_id not in prompt_tokens and \
-                    penalized_token_id not in output_tokens)
-            assert (non_penalized_token_id in prompt_tokens or \
-                    non_penalized_token_id in output_tokens)
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
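Likewise, the repetition-penalty checks are consistent with the Hugging Face convention: for a token already seen in the prompt or output, a positive logit is divided by the penalty and a negative logit multiplied by it, so penalty > 1.0 always demotes seen tokens and penalty < 1.0 promotes them. A sketch, assuming a hypothetical boolean `seen_mask` over the vocabulary:

import torch

def apply_repetition_penalty_sketch(logits: torch.Tensor,
                                    seen_mask: torch.Tensor,
                                    repetition_penalty: float) -> torch.Tensor:
    penalized = torch.where(logits > 0, logits / repetition_penalty,
                            logits * repetition_penalty)
    # Only tokens present in the prompt or prior output are touched.
    return torch.where(seen_mask, penalized, logits)
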
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                                                              1e-2)
         else:
             assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if token_id >= start and token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
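
For end-to-end context, a hypothetical usage sketch of the feature this commit enables, assuming the existing SamplingParams.allowed_token_ids field is what gets turned into the mask tested above (model name and token IDs are illustrative only):

from vllm import LLM, SamplingParams

# Restrict generation to three token IDs (hypothetical values).
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=8, allowed_token_ids=[10, 11, 12])
outputs = llm.generate(["The answer is"], params)
print(outputs[0].outputs[0].text)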