[Sampler] Support returning all prompt logprobs (#23868)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
|
||||
|
||||
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Engine should return all vocabulary logprobs
|
||||
"""Engine should return all vocabulary logprobs and prompt logprobs
|
||||
|
||||
Args:
|
||||
example_prompts: list of example prompts (test fixture)
|
||||
@@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256)
|
||||
|
||||
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
|
||||
logprobs=-1)
|
||||
logprobs=-1,
|
||||
prompt_logprobs=-1)
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all)
|
||||
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
|
||||
|
||||
for i in range(len(results_logprobs_all)):
|
||||
logprobs = results_logprobs_all[i].outputs[0].logprobs
|
||||
prompt_logprobs = results_logprobs_all[i].prompt_logprobs
|
||||
assert logprobs is not None
|
||||
for logprob in logprobs:
|
||||
assert len(logprob) == vocab_size
|
||||
assert prompt_logprobs is not None
|
||||
assert prompt_logprobs[0] is None
|
||||
for prompt_logprob in prompt_logprobs[1:]:
|
||||
assert len(prompt_logprob) == vocab_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
|
||||
|
||||
Reference in New Issue
Block a user