[BugFix] Fix min_tokens when eos_token_id is None (#4389)

Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
This commit is contained in:
Nick Hill
2024-04-27 09:52:46 -07:00
committed by GitHub
parent dfea173148
commit 81661da7b2
4 changed files with 14 additions and 18 deletions

View File

@@ -431,9 +431,10 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# Add the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)