[V1] Support bad_words in sampler (#13376)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
22quinn
2025-03-08 14:50:26 -08:00
committed by GitHub
parent 9513290032
commit eb8b5eb183
13 changed files with 266 additions and 28 deletions

View File

@@ -120,8 +120,22 @@ def test_detokenize_false(model):
def test_bad_words(model):
"""Check that we respect bad words."""
with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
output = model.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split()
bad_words_1 = " ".join(split_text[:2])
params = SamplingParams(temperature=0, bad_words=[bad_words_1])
output = model.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0,
bad_words=[bad_words_1, bad_words_2])
output = model.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
assert bad_words_2 not in new_text
def test_logits_processor(model):