Rename variables and methods (#91)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Set, Dict
|
||||
from typing import Dict, Set
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
@@ -12,7 +12,6 @@ class SamplingParams:
|
||||
stop_token_ids: Set[int],
|
||||
max_num_steps: int,
|
||||
num_logprobs: int,
|
||||
context_window_size: Optional[int],
|
||||
) -> None:
|
||||
if n < 1:
|
||||
raise ValueError(f'n must be at least 1, got {n}.')
|
||||
@@ -27,10 +26,6 @@ class SamplingParams:
|
||||
if num_logprobs < 0:
|
||||
raise ValueError(
|
||||
f'num_logprobs must be non-negative, got {num_logprobs}.')
|
||||
if context_window_size is not None and context_window_size < 0:
|
||||
raise ValueError(
|
||||
'context_window_size must be non-negative, '
|
||||
f'got {context_window_size}.')
|
||||
|
||||
if use_beam_search:
|
||||
if n == 1:
|
||||
@@ -58,7 +53,6 @@ class SamplingParams:
|
||||
self.stop_token_ids = stop_token_ids
|
||||
self.max_num_steps = max_num_steps
|
||||
self.num_logprobs = num_logprobs
|
||||
self.context_window_size = context_window_size
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'SamplingParams(n={self.n}, '
|
||||
@@ -67,8 +61,7 @@ class SamplingParams:
|
||||
f'use_beam_search={self.use_beam_search}, '
|
||||
f'stop_token_ids={self.stop_token_ids}, '
|
||||
f'max_num_steps={self.max_num_steps}, '
|
||||
f'num_logprobs={self.num_logprobs}, '
|
||||
f'context_window_size={self.context_window_size})')
|
||||
f'num_logprobs={self.num_logprobs}')
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict) -> 'SamplingParams':
|
||||
@@ -80,5 +73,4 @@ class SamplingParams:
|
||||
stop_token_ids=set(d.get('stop_token_ids', set())),
|
||||
max_num_steps=d.get('max_num_steps', 16),
|
||||
num_logprobs=d.get('num_logprobs', 0),
|
||||
context_window_size=d.get('context_window_size', None),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user