Implement presence and frequency penalties (#95)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -31,6 +32,16 @@ class Sampler(nn.Module):
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.vocab_size]
|
||||
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(
|
||||
logits, output_tokens, presence_penalties, frequency_penalties,
|
||||
self.vocab_size)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
assert len(temperatures) == logits.shape[0]
|
||||
@@ -43,16 +54,14 @@ class Sampler(nn.Module):
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p).
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs = _apply_top_p_top_k(probs, p, k)
|
||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
@@ -72,6 +81,93 @@ def _prune_hidden_states(
|
||||
return hidden_states[last_token_indicies]
|
||||
|
||||
|
||||
def _get_penalties(
    input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
    """Gather the presence and frequency penalties, one entry per row.

    Walks ``input_metadata.seq_groups`` in order. A group whose position is
    below ``input_metadata.num_prompts`` is treated as a prompt input and
    contributes a single entry; every other group is a generation step and
    contributes one entry per sequence id in the group.

    Returns:
        A ``(presence_penalties, frequency_penalties)`` pair of lists that
        always have the same length.
    """
    presence: List[float] = []
    frequency: List[float] = []
    for group_idx, (seq_ids, sampling_params) in enumerate(
            input_metadata.seq_groups):
        presence_value = sampling_params.presence_penalty
        frequency_value = sampling_params.frequency_penalty
        # A prompt input yields one row; a generation group yields one row
        # for each of its sequences.
        is_prompt = group_idx < input_metadata.num_prompts
        repeat = 1 if is_prompt else len(seq_ids)
        presence.extend([presence_value] * repeat)
        frequency.extend([frequency_value] * repeat)
    return presence, frequency
|
||||
|
||||
|
||||
def _get_output_tokens(
    input_metadata: InputMetadata,
) -> List[List[int]]:
    """Collect the output token ids, one list per row.

    Walks ``input_metadata.seq_groups`` in order. For a prompt input (group
    position below ``input_metadata.num_prompts``) only the first sequence id
    of the group is looked up; for a generation step every sequence id in the
    group is looked up. Each lookup appends
    ``input_metadata.seq_data[seq_id].output_token_ids`` as-is (no copy).
    """
    collected: List[List[int]] = []
    for group_idx, (seq_ids, _) in enumerate(input_metadata.seq_groups):
        if group_idx < input_metadata.num_prompts:
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            selected = [seq_ids[0]]
        else:
            # A generation step: one row per sequence in the group.
            selected = seq_ids
        collected.extend(
            input_metadata.seq_data[seq_id].output_token_ids
            for seq_id in selected)
    return collected
|
||||
|
||||
|
||||
def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
    """Subtract presence and frequency penalties from ``logits``.

    Only rows that have at least one output token and at least one non-zero
    penalty are touched. For those rows, each vocabulary entry is reduced by
    ``frequency_penalty * count`` and, when ``count > 0``, additionally by
    ``presence_penalty``, where ``count`` is the number of times the token
    id occurs in that row's output tokens.

    Args:
        logits: Tensor whose first dimension indexes rows; the penalty is
            broadcast against a ``(selected_rows, vocab_size)`` count matrix.
        output_tokens: Per-row lists of already generated token ids.
        presence_penalties: Per-row presence penalty.
        frequency_penalties: Per-row frequency penalty.
        vocab_size: Minimum length of each per-row token-count histogram.

    Returns:
        The same ``logits`` tensor object. Selected rows are updated in
        place; if no row qualifies, the tensor is returned untouched.
    """
    num_seqs = logits.shape[0]
    # Rows without output tokens, or with both penalties equal to zero,
    # would be left unchanged, so skip them up front.
    indices = [
        i for i, tokens, presence, frequency in zip(
            range(num_seqs), output_tokens, presence_penalties,
            frequency_penalties)
        if tokens and (presence != 0.0 or frequency != 0.0)
    ]

    # Return early if no sequence needs a penalty.
    if not indices:
        return logits

    # Per-row histogram of output token ids, padded to at least vocab_size
    # bins so that the rows can be stacked.
    bin_counts = np.stack(
        [np.bincount(output_tokens[i], minlength=vocab_size)
         for i in indices],
        axis=0)
    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
                                                 device=logits.device)

    # Keep the list parameters intact and build the tensors under new names.
    frequency = torch.tensor(
        [frequency_penalties[i] for i in indices],
        dtype=logits.dtype, device=logits.device)
    presence = torch.tensor(
        [presence_penalties[i] for i in indices],
        dtype=logits.dtype, device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits[indices] -= frequency.unsqueeze(dim=1) * bin_counts
    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
    logits[indices] -= presence.unsqueeze(dim=1) * presence_mask
    return logits
|
||||
|
||||
|
||||
def _get_temperatures(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[float]:
|
||||
@@ -121,10 +217,11 @@ def _get_top_p_top_k(
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
probs: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
) -> torch.Tensor:
|
||||
# TODO(woosuk): Optimize.
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
|
||||
# Apply top-p.
|
||||
@@ -286,7 +383,8 @@ def _sample(
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
|
||||
input_metadata.seq_data[seq_id].cumulative_logprobs
|
||||
for seq_id in seq_ids]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user