[Core] *Prompt* logprobs support in Multi-step (#8199)

This commit is contained in:
afeldman-nm
2024-09-18 11:38:43 -04:00
committed by GitHub
parent 7c7714d856
commit a8c1d161a7
5 changed files with 299 additions and 58 deletions

View File

@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import Logprob, SampleLogprobs
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
TokensText = Tuple[List[int], str]
@@ -34,20 +34,47 @@ def check_outputs_equal(
assert output_ids_0 == output_ids_1, fail_msg
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
# Allow for tokens to be represented as str's rather than IDs
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
def check_logprobs_close(
*,
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
@@ -57,6 +84,18 @@ def check_logprobs_close(
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
How sample logprobs are compared:
* `always_check_logprobs == True`: set of highest-logprob token ids
must match between seq0 and seq1 at all sampled token offsets
* `always_check_logprobs == False`: highest-logprob token ids are
only compared at sampled token offsets for which generated token
ids don't match
Prompt logprobs must be provided either for both input sequences, or
for neither. If prompt logprobs are provided, then highest-logprob
prompt token ids must match between seq0 and seq1 at all prompt token
offsets.
Args:
outputs_0_lst: First sequence to compare
outputs_0_lst: Second sequence to compare
@@ -78,8 +117,65 @@ def check_logprobs_close(
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
assert len(outputs_0) == len(outputs_1)
if len(outputs_0) == 3:
assert len(outputs_1) == 3
# Break out tokens, text & sample logprobs
# (prompt logprobs were not provided)
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
elif len(outputs_0) == 4:
assert len(outputs_1) == 4
# Break out tokens, text, sample logprobs & prompt logprobs
(
output_ids_0,
output_str_0,
logprobs_0,
prompt_logprobs_0,
) = outputs_0
(
output_ids_1,
output_str_1,
logprobs_1,
prompt_logprobs_1,
) = outputs_1
# Test prompt logprobs closeness
if (prompt_logprobs_0 is not None
and prompt_logprobs_1 is not None):
# Both sequences' prompt logprobs lists are not `None``
# (although individual list elements may be `None`);
# for each token's logprobs:
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
zip(prompt_logprobs_0, prompt_logprobs_1)):
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
if logprobs_elem_0 is None:
# If the seq 0 token's logprobs are `None`,
# the seq 1 token's logprobs must be `None`
assert logprobs_elem_1 is None, fail_msg
else:
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
else:
# Both sequence logprobs lists must be `None`
fail_msg = (f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
assert (prompt_logprobs_0 is None
and prompt_logprobs_1 is None), fail_msg
else:
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}")
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)