[V1] Adding min tokens/repetition/presence/frequence penalties to V1 sampler (#10681)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
sroy745
2024-12-26 02:02:58 -08:00
committed by GitHub
parent 7492a36207
commit dcb1a944d4
11 changed files with 879 additions and 49 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict
from typing import Dict, List, Optional, Set
import torch
@@ -19,3 +19,13 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator]
max_num_logprobs: int
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]