[BugFix] Fix handling of stop strings and stop token ids (#3672)

This commit is contained in:
Nick Hill
2024-04-11 23:34:12 +01:00
committed by GitHub
parent 1e96c3341a
commit e46a60aa4c
8 changed files with 206 additions and 41 deletions

View File

@@ -166,6 +166,13 @@ class SamplingParams:
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
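# Number of characters to hold back for stop string evaluation until the
# sequence is finished: one less than the longest stop string, or 0 when
# no stop strings are set or they are included in the output.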
if self.stop and not include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -226,6 +233,8 @@ class SamplingParams:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "