Support beam search & parallel generation (#7)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -25,12 +25,35 @@ class Frontend:
|
||||
def query(
    self,
    prompt: str,
    sampling_params: Optional[SamplingParams] = None,
    n: int = 1,
    temperature: float = 1.0,
    top_p: float = 1.0,
    use_beam_search: bool = False,
    stop_token_ids: Optional[Set[int]] = None,
    max_num_steps: int = 16,  # From OpenAI API.
    num_logprobs: int = 0,
    context_window_size: Optional[int] = None,
) -> None:
    """Tokenize ``prompt`` and enqueue it with newly built sampling params.

    The prompt is encoded with ``self.tokenizer``, a ``SamplingParams``
    object is constructed from the keyword arguments, and both are handed
    to ``self._add_query``.

    Args:
        prompt: Text to encode with ``self.tokenizer.encode``.
        sampling_params: Accepted for backward compatibility only. The
            value is not used: the parameters passed on are always built
            from the keyword arguments below.
        n: Forwarded to ``SamplingParams(n=...)``.
        temperature: Forwarded to ``SamplingParams(temperature=...)``.
        top_p: Forwarded to ``SamplingParams(top_p=...)``.
        use_beam_search: Forwarded to
            ``SamplingParams(use_beam_search=...)``.
        stop_token_ids: Extra token ids to stop on. ``None`` (the default)
            means no extra ids. The tokenizer's EOS token id is always
            added. The caller's set is copied, never modified.
        max_num_steps: Forwarded to ``SamplingParams(max_num_steps=...)``.
        num_logprobs: Forwarded to ``SamplingParams(num_logprobs=...)``.
        context_window_size: Forwarded to
            ``SamplingParams(context_window_size=...)``.
    """
    token_ids: List[int] = self.tokenizer.encode(prompt)

    # Build a private set: mutating a default argument (or the caller's
    # own set) would leak stop tokens from one call into the next.
    stop_ids: Set[int] = (
        set() if stop_token_ids is None else set(stop_token_ids)
    )
    # Stop when we see an EOS token.
    stop_ids.add(self.tokenizer.eos_token_id)

    # NOTE: a caller-supplied ``sampling_params`` is superseded here; the
    # effective parameters come solely from the keyword arguments.
    sampling_params = SamplingParams(
        n=n,
        temperature=temperature,
        top_p=top_p,
        use_beam_search=use_beam_search,
        stop_token_ids=stop_ids,
        max_num_steps=max_num_steps,
        num_logprobs=num_logprobs,
        context_window_size=context_window_size,
    )
    self._add_query(token_ids, sampling_params)
|
||||
|
||||
def _add_query(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
sampling_params: SamplingParams,
|
||||
) -> None:
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.n):
|
||||
seq_id = next(self.seq_counter)
|
||||
|
||||
Reference in New Issue
Block a user