[Sampler] Vectorized sampling (simplified) (#1048)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class SamplingType(IntEnum):
|
||||
GREEDY = 0
|
||||
RANDOM = 1
|
||||
BEAM = 2
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
@@ -166,6 +174,14 @@ class SamplingParams:
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||
|
||||
@cached_property
|
||||
def sampling_type(self) -> SamplingType:
|
||||
if self.use_beam_search:
|
||||
return SamplingType.BEAM
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
return SamplingType.GREEDY
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SamplingParams(n={self.n}, "
|
||||
f"best_of={self.best_of}, "
|
||||
|
||||
Reference in New Issue
Block a user