Add skip_special_tokens sampling params (#1186)

This commit is contained in:
Dan Lord
2023-09-27 19:21:42 -07:00
committed by GitHub
parent 649aa730c5
commit 20f7cc4cde
4 changed files with 14 additions and 4 deletions

View File

@@ -387,7 +387,7 @@ class LLMEngine:
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
-self._decode_sequence(seq)
+self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
@@ -621,7 +621,8 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
-def _decode_sequence(self, seq: Sequence) -> None:
+def _decode_sequence(self, seq: Sequence,
+sampling_params: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ class LLMEngine:
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
-skip_special_tokens=True,
+skip_special_tokens=sampling_params.skip_special_tokens,
)
if seq.tokens is None:
seq.tokens = new_tokens