Added logits processor API to sampling params (#1469)

This commit is contained in:
Noam Gat
2023-11-03 23:12:15 +02:00
committed by GitHub
parent 54ca1ba71d
commit 555bdcc5a3
3 changed files with 70 additions and 2 deletions

View File

@@ -1,7 +1,8 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import torch
_SAMPLING_EPS = 1e-5
@@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM = 2
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
class SamplingParams:
"""Sampling parameters for text generation.
@@ -73,6 +80,8 @@ class SamplingParams:
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
"""
def __init__(
@@ -96,6 +105,7 @@ class SamplingParams:
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -124,7 +134,7 @@ class SamplingParams:
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()