[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@@ -45,11 +46,27 @@ def check_logprobs_close(
|
||||
outputs_1_lst: Sequence[TokensTextLogprobs],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
num_outputs_0_skip_tokens: int = 0,
|
||||
warn_on_mismatch: bool = True,
|
||||
):
|
||||
"""
|
||||
Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
|
||||
Arguments:
|
||||
|
||||
* outputs_0_lst: First sequence to compare
|
||||
* outputs_0_lst: Second sequence to compare
|
||||
* name_0: sequence #0 name
|
||||
* name_1: sequence #1 name
|
||||
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
|
||||
sequence #0 tokens & logprobs to discard
|
||||
before comparison, i.e. all
|
||||
of sequence #1 will be compared to
|
||||
sequence #0 beginning at index
|
||||
num_outputs_0_skip_tokens
|
||||
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
|
||||
mismatch between the two sequences
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
@@ -65,6 +82,15 @@ def check_logprobs_close(
|
||||
if logprobs_1 is None:
|
||||
logprobs_1 = [None] * len(output_ids_1)
|
||||
|
||||
# Skip specified number of initial sequence #0 tokens
|
||||
# & logprobs, leaving output text as-is for simplicity
|
||||
# (text mismatches may generate warnings but do not
|
||||
# cause the test to fail.)
|
||||
if num_outputs_0_skip_tokens < 0:
|
||||
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
|
||||
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
|
||||
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
|
||||
|
||||
# Loop through generated tokens.
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
@@ -110,3 +136,13 @@ def check_logprobs_close(
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
|
||||
class DecoderPromptType(Enum):
|
||||
'''
|
||||
For encoder/decoder models only -
|
||||
|
||||
'''
|
||||
CUSTOM = 1
|
||||
NONE = 2
|
||||
EMPTY_STR = 3
|
||||
|
||||
Reference in New Issue
Block a user