[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -120,8 +120,22 @@ def test_detokenize_false(model):
|
||||
def test_bad_words(model):
|
||||
"""Check that we respect bad words."""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
|
||||
output = model.generate(PROMPT, SamplingParams(temperature=0))
|
||||
split_text = output[0].outputs[0].text.split()
|
||||
|
||||
bad_words_1 = " ".join(split_text[:2])
|
||||
params = SamplingParams(temperature=0, bad_words=[bad_words_1])
|
||||
output = model.generate(PROMPT, params)
|
||||
new_text = output[0].outputs[0].text
|
||||
assert bad_words_1 not in new_text
|
||||
|
||||
bad_words_2 = new_text.split()[-1]
|
||||
params = SamplingParams(temperature=0,
|
||||
bad_words=[bad_words_1, bad_words_2])
|
||||
output = model.generate(PROMPT, params)
|
||||
new_text = output[0].outputs[0].text
|
||||
assert bad_words_1 not in new_text
|
||||
assert bad_words_2 not in new_text
|
||||
|
||||
|
||||
def test_logits_processor(model):
|
||||
|
||||
Reference in New Issue
Block a user