remove floats == 0 comparison (#285)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
@@ -71,7 +72,7 @@ class SamplingParams:
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verity_beam_search()
|
||||
elif self.temperature == 0.0:
|
||||
elif self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
@@ -106,9 +107,9 @@ class SamplingParams:
|
||||
if self.best_of == 1:
|
||||
raise ValueError("best_of must be greater than 1 when using beam "
|
||||
f"search. Got {self.best_of}.")
|
||||
if self.temperature > 0.0:
|
||||
if self.temperature > _SAMPLING_EPS:
|
||||
raise ValueError("temperature must be 0 when using beam search.")
|
||||
if self.top_p < 1.0:
|
||||
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
@@ -117,7 +118,7 @@ class SamplingParams:
|
||||
if self.best_of > 1:
|
||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||
f"Got {self.best_of}.")
|
||||
if self.top_p < 1.0:
|
||||
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||
raise ValueError("top_p must be 1 when using greedy sampling.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||
|
||||
Reference in New Issue
Block a user