[Refactor] Simplify BOS/EOS token handling (#34435)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -223,6 +223,7 @@ class SamplingParams(
|
||||
# The below fields are not supposed to be used as an input.
|
||||
# They are set in post_init.
|
||||
output_text_buffer_length: int = 0
|
||||
_eos_token_id: int | None = None
|
||||
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
|
||||
|
||||
# Fields used to construct logits processors
|
||||
@@ -477,24 +478,26 @@ class SamplingParams(
|
||||
def update_from_generation_config(
|
||||
self,
|
||||
generation_config: dict[str, Any],
|
||||
model_eos_token_id: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
) -> None:
|
||||
"""Update if there are non-default values from generation_config"""
|
||||
if not self.ignore_eos:
|
||||
self._eos_token_id = eos_token_id
|
||||
|
||||
if model_eos_token_id is not None:
|
||||
if eos_token_id is not None:
|
||||
# Add the eos token id into the sampling_params to support
|
||||
# min_tokens processing.
|
||||
self._all_stop_token_ids.add(model_eos_token_id)
|
||||
self._all_stop_token_ids.add(eos_token_id)
|
||||
|
||||
# Update eos_token_id for generation
|
||||
if (eos_ids := generation_config.get("eos_token_id")) is not None:
|
||||
# it can be either int or list of int
|
||||
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||
if model_eos_token_id is not None:
|
||||
if eos_token_id is not None:
|
||||
# We don't need to include the primary eos_token_id in
|
||||
# stop_token_ids since it's handled separately for stopping
|
||||
# purposes.
|
||||
eos_ids.discard(model_eos_token_id)
|
||||
eos_ids.discard(eos_token_id)
|
||||
if eos_ids:
|
||||
self._all_stop_token_ids.update(eos_ids)
|
||||
if not self.ignore_eos:
|
||||
@@ -550,6 +553,10 @@ class SamplingParams(
|
||||
return SamplingType.RANDOM_SEED
|
||||
return SamplingType.RANDOM
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int | None:
|
||||
return self._eos_token_id
|
||||
|
||||
@property
|
||||
def all_stop_token_ids(self) -> set[int]:
|
||||
return self._all_stop_token_ids
|
||||
|
||||
Reference in New Issue
Block a user