Add skip_special_tokens sampling params (#1186)
This commit is contained in:
@@ -387,7 +387,7 @@ class LLMEngine:
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._decode_sequence(seq, seq_group.sampling_params)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
@@ -621,7 +621,8 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
def _decode_sequence(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
(new_tokens, new_output_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
@@ -630,7 +631,7 @@ class LLMEngine:
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=True,
|
||||
skip_special_tokens=sampling_params.skip_special_tokens,
|
||||
)
|
||||
if seq.tokens is None:
|
||||
seq.tokens = new_tokens
|
||||
|
||||
Reference in New Issue
Block a user