Support top-k sampling (#94)
This commit is contained in:
@@ -46,12 +46,13 @@ class Sampler(nn.Module):
|
||||
# Compute the log probabilities (before applying top-p).
|
||||
logprobs = torch.log(probs)
|
||||
|
||||
# Apply top-p truncation.
|
||||
top_ps = _get_top_ps(input_metadata)
|
||||
assert len(top_ps) == probs.shape[0]
|
||||
if any(p < 1.0 for p in top_ps):
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
probs = _apply_top_p(probs, p)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs = _apply_top_p_top_k(probs, p, k)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
@@ -94,31 +95,51 @@ def _get_temperatures(
|
||||
return temperatures
|
||||
|
||||
|
||||
def _get_top_ps(
|
||||
def _get_top_p_top_k(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[float]:
|
||||
vocab_size: int,
|
||||
) -> Tuple[List[float], List[int]]:
|
||||
top_ps: List[float] = []
|
||||
top_ks: List[int] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
top_p = sampling_params.top_p
|
||||
# k should not be greater than the vocab size.
|
||||
top_k = min(sampling_params.top_k, vocab_size)
|
||||
# k=-1 means no truncation.
|
||||
top_k = vocab_size if top_k == -1 else top_k
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
top_ps.append(sampling_params.top_p)
|
||||
top_ps.append(top_p)
|
||||
top_ks.append(top_k)
|
||||
else:
|
||||
# A generation token.
|
||||
top_ps += [sampling_params.top_p] * len(seq_ids)
|
||||
return top_ps
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
return top_ps, top_ks
|
||||
|
||||
|
||||
def _apply_top_p(
|
||||
def _apply_top_p_top_k(
|
||||
probs: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# TODO(woosuk): Optimize.
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
|
||||
# Apply top-p.
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||
probs_sort[top_p_mask] = 0.0
|
||||
|
||||
# Apply top-k.
|
||||
# Create a mask for the top-k elements.
|
||||
top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
|
||||
top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
|
||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
||||
probs_sort[top_k_mask] = 0.0
|
||||
|
||||
# Re-sort the probabilities.
|
||||
probs = torch.gather(
|
||||
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
|
||||
return probs
|
||||
@@ -160,7 +181,7 @@ def _sample_from_prompt(
|
||||
next_token_id = torch.argmax(prob)
|
||||
next_token_ids = [next_token_id.item()]
|
||||
else:
|
||||
# Neucleus sampling.
|
||||
# Random sampling.
|
||||
# Sample n tokens for the prompt.
|
||||
n = sampling_params.n
|
||||
next_token_ids = torch.multinomial(
|
||||
@@ -218,7 +239,7 @@ def _sample_from_generation_tokens(
|
||||
next_token_ids = [next_token_id.item()]
|
||||
parent_seq_ids = seq_ids
|
||||
else:
|
||||
# Neucleus sampling.
|
||||
# Random sampling.
|
||||
# Sample 1 token for each sequence in the group.
|
||||
next_token_ids = torch.multinomial(
|
||||
probs, num_samples=1, replacement=True)
|
||||
|
||||
Reference in New Issue
Block a user