[v1] Support allowed_token_ids in v1 Sampler (#13210)

Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
Lu Fang
2025-02-21 22:13:05 -08:00
committed by GitHub
parent 8aca27fa11
commit bb78fb318e
7 changed files with 168 additions and 19 deletions

View File

@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
)
@@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)