Add skip_special_tokens sampling params (#1186)

This commit is contained in:
Dan Lord
2023-09-27 19:21:42 -07:00
committed by GitHub
parent 649aa730c5
commit 20f7cc4cde
4 changed files with 14 additions and 4 deletions

View File

@@ -60,6 +60,8 @@ class SamplingParams:
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
        skip_special_tokens: Whether to skip special tokens in the output.
            Defaults to True.
"""
def __init__(
@@ -79,6 +81,7 @@ class SamplingParams:
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -103,6 +106,7 @@ class SamplingParams:
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
self.skip_special_tokens = skip_special_tokens
self._verify_args()
if self.use_beam_search:
@@ -196,4 +200,5 @@ class SamplingParams:
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs})")
f"logprobs={self.logprobs}, "
f"skip_special_tokens={self.skip_special_tokens})")