[V1] support min_tokens for detokener (#22014)
Signed-off-by: calvin chen <wen.chen@dynamia.ai> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
50
tests/detokenizer/test_min_tokens.py
Normal file
50
tests/detokenizer/test_min_tokens.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer
|
||||||
|
|
||||||
|
PROMPT = "Hello, my name is Lee, and I'm a student in the " + \
|
||||||
|
"college of engineering"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("min_tokens,stop,truth", [
|
||||||
|
(0, None, " is Lee, and I'm a student in the college of engineering"),
|
||||||
|
(0, "e", " is L"),
|
||||||
|
(5, "e", " is Lee, and I'm a stud"),
|
||||||
|
])
|
||||||
|
def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str):
|
||||||
|
"""Test for a specific min_tokens and stop.
|
||||||
|
|
||||||
|
See https://github.com/vllm-project/vllm/pull/22014
|
||||||
|
"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
||||||
|
all_prompt_ids = tokenizer(PROMPT, add_special_tokens=False).input_ids
|
||||||
|
|
||||||
|
# The prompt is "Hello, my name is"
|
||||||
|
prompt_token_ids = all_prompt_ids[:4]
|
||||||
|
params = SamplingParams(
|
||||||
|
stop=stop,
|
||||||
|
min_tokens=min_tokens,
|
||||||
|
)
|
||||||
|
request = EngineCoreRequest("",
|
||||||
|
prompt_token_ids,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
params,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
None,
|
||||||
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None)
|
||||||
|
|
||||||
|
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
|
||||||
|
|
||||||
|
detokenizer.update(all_prompt_ids[4:], False)
|
||||||
|
assert detokenizer.output_text == truth
|
||||||
@@ -74,6 +74,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
params = request.sampling_params
|
params = request.sampling_params
|
||||||
assert params is not None
|
assert params is not None
|
||||||
self.stop = stop = params.stop
|
self.stop = stop = params.stop
|
||||||
|
self.min_tokens = params.min_tokens
|
||||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||||
|
|
||||||
# Number of chars to hold back when stop strings are to be excluded
|
# Number of chars to hold back when stop strings are to be excluded
|
||||||
@@ -111,10 +112,14 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
# 1) Detokenize the new token ids incrementally.
|
# 1) Detokenize the new token ids incrementally.
|
||||||
# TODO(woosuk): This method becomes very inefficient when the number of
|
# TODO(woosuk): This method becomes very inefficient when the number of
|
||||||
# new_token_ids is more than 1. We need to optimize this.
|
# new_token_ids is more than 1. We need to optimize this.
|
||||||
offset_before = len(self.output_text)
|
stop_check_offset = len(self.output_text)
|
||||||
for new_token_id in new_token_ids:
|
for new_token_id in new_token_ids:
|
||||||
self.token_ids.append(new_token_id)
|
self.token_ids.append(new_token_id)
|
||||||
self.output_text += self.decode_next(new_token_id)
|
self.output_text += self.decode_next(new_token_id)
|
||||||
|
# Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
|
||||||
|
if self.min_tokens and len(
|
||||||
|
self.output_token_ids) <= self.min_tokens:
|
||||||
|
stop_check_offset = len(self.output_text)
|
||||||
|
|
||||||
if stop_terminated:
|
if stop_terminated:
|
||||||
if skipped_stop_token_id is not None:
|
if skipped_stop_token_id is not None:
|
||||||
@@ -125,10 +130,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
|
|
||||||
# 2) Evaluate stop strings.
|
# 2) Evaluate stop strings.
|
||||||
stop_string = None
|
stop_string = None
|
||||||
if self.stop:
|
if self.stop and len(self.output_token_ids) > self.min_tokens:
|
||||||
stop = StopChecker.check_stop_strings(
|
stop = StopChecker.check_stop_strings(
|
||||||
output_text=self.output_text,
|
output_text=self.output_text,
|
||||||
new_char_count=len(self.output_text) - offset_before,
|
new_char_count=len(self.output_text) - stop_check_offset,
|
||||||
stop=self.stop,
|
stop=self.stop,
|
||||||
include_in_output=self.include_stop_str_in_output,
|
include_in_output=self.include_stop_str_in_output,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user