[Core] Dynamic image size support for VLMs (#5276)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2024-07-03 11:34:00 +08:00
committed by GitHub
parent 482045ee77
commit 9831aec49f
38 changed files with 1453 additions and 664 deletions

View File

@@ -1,11 +1,18 @@
from typing import Dict, List, Tuple
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText],
outputs_1_lst: List[TokensText], name_0: str,
name_1: str):
def check_outputs_equal(
*,
outputs_0_lst: Sequence[TokensText],
outputs_1_lst: Sequence[TokensText],
name_0: str,
name_1: str,
):
"""
Compare the two sequences generated by different models,
which should be equal.
@@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText],
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
# The text and token outputs should exactly match
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_str_0 == output_str_1, fail_msg
assert output_ids_0 == output_ids_1, fail_msg
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
outputs_1_lst: List[TokensTextLogprobs], name_0: str,
name_1: str):
def check_logprobs_close(
*,
outputs_0_lst: Sequence[TokensTextLogprobs],
outputs_1_lst: Sequence[TokensTextLogprobs],
name_0: str,
name_1: str,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
@@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1)
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx]
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (
fail_msg = (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
assert logprobs_elem_0 is not None, fail_msg
assert logprobs_elem_1 is not None, fail_msg
assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg
if warn_on_mismatch:
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
# Break out since sequences will now diverge.
break
else:
if output_str_0 != output_str_1 and warn_on_mismatch:
# The token outputs exactly match,
# so the text outputs should exactly match as well
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)