Support eos_token_id from generation_config.json (#4182)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
@@ -271,6 +271,18 @@ class SamplingParams:
|
||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||
f"Got {self.best_of}.")
|
||||
|
||||
def update_from_generation_config(
|
||||
self, generation_config: Dict[str, Any]) -> None:
|
||||
"""Update if there are non-default values from generation_config"""
|
||||
# Update eos_token_id for generation
|
||||
if eos_ids := generation_config.get("eos_token_id"):
|
||||
# it can be either int or list of int
|
||||
if isinstance(eos_ids, int):
|
||||
eos_ids = [eos_ids]
|
||||
original_stop_token_ids = set(self.stop_token_ids)
|
||||
original_stop_token_ids.update(eos_ids)
|
||||
self.stop_token_ids = list(original_stop_token_ids)
|
||||
|
||||
@cached_property
|
||||
def sampling_type(self) -> SamplingType:
|
||||
if self.use_beam_search:
|
||||
|
||||
Reference in New Issue
Block a user