[BugFix] Fix min_tokens behaviour for multiple eos tokens (#5849)

This commit is contained in:
Nick Hill
2024-06-27 11:31:11 -07:00
committed by GitHub
parent 691e29ecf3
commit 365791ff81
2 changed files with 23 additions and 13 deletions

View File

@@ -606,12 +606,9 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
- # Add the eos token id into the sampling_params to support min_tokens
- # processing
- if seq.eos_token_id is not None:
- sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
- self.generation_config_fields)
+ self.generation_config_fields, seq.eos_token_id)
# Create the sequence group.
seq_group = SequenceGroup(