Implement block copy kernel to optimize beam search (#32)

This commit is contained in:
Woosuk Kwon
2023-04-07 17:45:07 -07:00
committed by GitHub
parent a490aafa36
commit 0f40557af6
6 changed files with 154 additions and 48 deletions

View File

@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = (topk_ids % vocab_size).tolist()
token_ids = [i % vocab_size for i in topk_ids]
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []