[BugFix] Fix min_tokens when eos_token_id is None (#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
This commit is contained in:
@@ -431,9 +431,10 @@ class LLMEngine:
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
# inject the eos token id into the sampling_params to support min_tokens
|
||||
# Add the eos token id into the sampling_params to support min_tokens
|
||||
# processing
|
||||
sampling_params.eos_token_id = seq.eos_token_id
|
||||
if seq.eos_token_id is not None:
|
||||
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user