[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
|
||||
VOCAB_SIZE,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
bad_words_token_ids = {}
|
||||
for req in reqs:
|
||||
if req.req_id not in req_ids_retained:
|
||||
continue
|
||||
@@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata(
|
||||
if req.sampling_params.allowed_token_ids:
|
||||
allowed_token_ids_mask[index_in_input_batch][
|
||||
req.sampling_params.allowed_token_ids] = True
|
||||
bad_words_token_ids[
|
||||
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
|
||||
and all(x == 1 for x in repetition_penalties)),
|
||||
logit_bias=logit_bias,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=bad_words_token_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -284,6 +288,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.allowed_token_ids_mask,
|
||||
sampling_metadata.allowed_token_ids_mask)
|
||||
assert expected_sampling_metadata.bad_words_token_ids == \
|
||||
sampling_metadata.bad_words_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
|
||||
Reference in New Issue
Block a user