[Sampler] Vectorized sampling (simplified) (#1048)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Zhuohan Li
2023-09-22 17:48:04 -07:00
committed by GitHub
parent 8d926e91f1
commit 947b794146
3 changed files with 475 additions and 174 deletions

View File

@@ -1,9 +1,17 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
# Threshold below which `temperature` is treated as effectively zero, i.e.
# greedy sampling (see `SamplingParams.sampling_type`).
_SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
    """Kind of sampling selected by a set of sampling parameters.

    Members are ``IntEnum`` values, so they compare equal to (and can be
    used as) their plain integer values.
    """

    # Selected when the temperature is below the sampling epsilon.
    GREEDY = 0
    # Selected when neither beam search nor greedy sampling applies.
    RANDOM = 1
    # Selected when beam search is enabled.
    BEAM = 2
class SamplingParams:
"""Sampling parameters for text generation.
@@ -166,6 +174,14 @@ class SamplingParams:
if self.top_k != -1:
raise ValueError("top_k must be -1 when using greedy sampling.")
@cached_property
def sampling_type(self) -> SamplingType:
    """Return the ``SamplingType`` implied by these parameters.

    The checks are ordered, so the first match wins:

    1. ``use_beam_search`` is truthy -> ``SamplingType.BEAM``
    2. ``temperature`` is below ``_SAMPLING_EPS`` -> ``SamplingType.GREEDY``
    3. otherwise -> ``SamplingType.RANDOM``

    Because this is a ``cached_property``, the result is computed on first
    access and then stored on the instance; later changes to
    ``use_beam_search`` or ``temperature`` do not refresh it.
    """
    # Beam search takes precedence over any temperature setting.
    if self.use_beam_search:
        return SamplingType.BEAM
    # A temperature under the epsilon is treated as zero, i.e. greedy.
    if self.temperature < _SAMPLING_EPS:
        return SamplingType.GREEDY
    return SamplingType.RANDOM
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "