[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
|
||||
temperature = [0.0 for _ in range(num_reqs)]
|
||||
min_tokens = {}
|
||||
logit_bias = [None] * num_reqs
|
||||
allowed_token_ids_mask = torch.zeros(num_reqs,
|
||||
VOCAB_SIZE,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
for req in reqs:
|
||||
if req.req_id not in req_ids_retained:
|
||||
continue
|
||||
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
|
||||
req.sampling_params.min_tokens,
|
||||
req.sampling_params.all_stop_token_ids)
|
||||
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
|
||||
if req.sampling_params.allowed_token_ids:
|
||||
allowed_token_ids_mask[index_in_input_batch][
|
||||
req.sampling_params.allowed_token_ids] = True
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||
device=device),
|
||||
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
and all(x == 1 for x in repetition_penalties)),
|
||||
logit_bias=logit_bias,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
assert expected_sampling_metadata.no_penalties == \
|
||||
sampling_metadata.no_penalties
|
||||
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
|
||||
if sampling_metadata.allowed_token_ids_mask:
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.allowed_token_ids_mask,
|
||||
sampling_metadata.allowed_token_ids_mask)
|
||||
|
||||
Reference in New Issue
Block a user