[V1] Do not detokenize if sampling param detokenize is False (#14224)

Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: Himanshu Jaju
Committed: 2025-03-06 19:40:24 +01:00, by GitHub
Parent: 9f1710f1ac
Commit: cd579352bf
4 changed files with 69 additions and 27 deletions
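
For reference, a minimal usage sketch of the option this commit honors. This is not taken from the change itself; the model name is an arbitrary example, and MODEL/PROMPT in the diff below refer to the test module's own constants:

from vllm import LLM, SamplingParams

# Any vLLM-supported model works here; "facebook/opt-125m" is just an example.
llm = LLM("facebook/opt-125m", enforce_eager=True)

# With detokenize=False, the engine returns sampled token IDs but skips the
# detokenization step, so the output text stays empty.
output = llm.generate("Hello my name is Robert and I",
                      SamplingParams(detokenize=False))

completion = output[0].outputs[0]
print(completion.token_ids)   # non-empty sequence of token IDs
print(repr(completion.text))  # '' -- nothing was detokenized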


@@ -14,7 +14,10 @@ PROMPT = "Hello my name is Robert and I"
 
 @pytest.fixture(scope="module")
 def model() -> LLM:
-    return LLM(MODEL, enforce_eager=True)
+    # Disable prefix caching so that we can test prompt logprobs.
+    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
+    # is merged
+    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
 
 
 def test_n_gt_1(model):
@@ -87,9 +90,33 @@ def test_stop_token_ids(model):
 
     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
     output = model.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
+
+
+def test_detokenize_false(model):
+    """Check that detokenize=False option works."""
+
+    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    output = model.generate(
+        PROMPT, SamplingParams(detokenize=False, logprobs=3,
+                               prompt_logprobs=3))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    prompt_logprobs = output[0].prompt_logprobs
+    sampled_logprobs = output[0].outputs[0].logprobs
+    assert len(prompt_logprobs) > 1
+    assert len(sampled_logprobs) > 1
+    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
+        for logprobs in all_logprobs:
+            assert 3 <= len(logprobs) <= 4
+            assert all(lp.decoded_token is None for lp in logprobs.values())
+
+
 def test_bad_words(model):
     """Check that we respect bad words."""