Support eos_token_id from generation_config.json (#4182)

This commit is contained in:
Simon Mo
2024-04-18 21:13:36 -07:00
committed by GitHub
parent 8a7a3e4436
commit a134ef6f5e
2 changed files with 30 additions and 3 deletions

View File

@@ -2,7 +2,7 @@
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from pydantic import Field
@@ -271,6 +271,18 @@ class SamplingParams:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
"""Update if there are non-default values from generation_config"""
# Update eos_token_id for generation
if eos_ids := generation_config.get("eos_token_id"):
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search: