[V1] Support min_tokens for detokenizer (#22014)
Signed-off-by: calvin chen <wen.chen@dynamia.ai> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -74,6 +74,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
self.stop = stop = params.stop
|
||||
self.min_tokens = params.min_tokens
|
||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
# Number of chars to hold back when stop strings are to be excluded
|
||||
@@ -111,10 +112,14 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
# 1) Detokenize the new token ids incrementally.
|
||||
# TODO(woosuk): This method becomes very inefficient when the number of
|
||||
# new_token_ids is more than 1. We need to optimize this.
|
||||
offset_before = len(self.output_text)
|
||||
stop_check_offset = len(self.output_text)
|
||||
for new_token_id in new_token_ids:
|
||||
self.token_ids.append(new_token_id)
|
||||
self.output_text += self.decode_next(new_token_id)
|
||||
# Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
|
||||
if self.min_tokens and len(
|
||||
self.output_token_ids) <= self.min_tokens:
|
||||
stop_check_offset = len(self.output_text)
|
||||
|
||||
if stop_terminated:
|
||||
if skipped_stop_token_id is not None:
|
||||
@@ -125,10 +130,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
|
||||
# 2) Evaluate stop strings.
|
||||
stop_string = None
|
||||
if self.stop:
|
||||
if self.stop and len(self.output_token_ids) > self.min_tokens:
|
||||
stop = StopChecker.check_stop_strings(
|
||||
output_text=self.output_text,
|
||||
new_char_count=len(self.output_text) - offset_before,
|
||||
new_char_count=len(self.output_text) - stop_check_offset,
|
||||
stop=self.stop,
|
||||
include_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user