[Refactor] Simplify BOS/EOS token handling (#34435)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-13 10:18:24 +08:00
committed by GitHub
parent 04ea31baab
commit ea5ff3a1f6
29 changed files with 123 additions and 123 deletions

View File

@@ -223,6 +223,7 @@ class SamplingParams(
# The fields below are internal state and are not supposed to be used as an
# input. They are set in post_init or, for the EOS/stop-token fields, in
# update_from_generation_config.
output_text_buffer_length: int = 0
# Primary EOS token id; assigned in update_from_generation_config and
# None until then.
_eos_token_id: int | None = None
# Stop token ids collected in update_from_generation_config: the EOS token
# id plus any additional ids listed under the generation config's
# "eos_token_id" entry.
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
@@ -477,24 +478,26 @@ class SamplingParams(
def update_from_generation_config(
self,
generation_config: dict[str, Any],
model_eos_token_id: int | None = None,
eos_token_id: int | None = None,
) -> None:
"""Update if there are non-default values from generation_config"""
if not self.ignore_eos:
self._eos_token_id = eos_token_id
if model_eos_token_id is not None:
if eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self._all_stop_token_ids.add(model_eos_token_id)
self._all_stop_token_ids.add(eos_token_id)
# Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
if eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(model_eos_token_id)
eos_ids.discard(eos_token_id)
if eos_ids:
self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
@@ -550,6 +553,10 @@ class SamplingParams(
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
# Read-only accessor for the internally stored `_eos_token_id`.
# Returns None when no EOS token id has been recorded (the field's default).
@property
def eos_token_id(self) -> int | None:
return self._eos_token_id
# Read-only accessor for the internally stored `_all_stop_token_ids`.
# Returns the stored set object itself, not a copy, so mutating the result
# mutates this instance's state.
@property
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids