remove floats == 0 comparison (#285)

This commit is contained in:
Lily Liu
2023-06-28 14:11:51 -07:00
committed by GitHub
parent 4338cc4750
commit 425040d4c1
2 changed files with 11 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
"""Sampling parameters for text generation.
@@ -71,7 +72,7 @@ class SamplingParams:
self._verify_args()
if self.use_beam_search:
self._verity_beam_search()
elif self.temperature == 0.0:
elif self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
@@ -106,9 +107,9 @@ class SamplingParams:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > 0.0:
if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0:
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
@@ -117,7 +118,7 @@ class SamplingParams:
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
if self.top_p < 1.0:
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using greedy sampling.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using greedy sampling.")