Added logits processor API to sampling params (#1469)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@@ -12,6 +13,12 @@ class SamplingType(IntEnum):
|
||||
BEAM = 2
|
||||
|
||||
|
||||
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
|
||||
"""LogitsProcessor is a function that takes a list of previously generated
|
||||
tokens and a tensor of the logits for the next token, and returns a modified
|
||||
tensor of logits to sample from."""
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
@@ -73,6 +80,8 @@ class SamplingParams:
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
spaces_between_special_tokens: Whether to add spaces between special
|
||||
tokens in the output. Defaults to True.
|
||||
logits_processors: List of functions that modify logits based on
|
||||
previously generated tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -96,6 +105,7 @@ class SamplingParams:
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@@ -124,7 +134,7 @@ class SamplingParams:
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
|
||||
self.logits_processors = logits_processors
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
|
||||
Reference in New Issue
Block a user