[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-18 12:15:33 -08:00
committed by GitHub
parent a4d577b379
commit 30172b4947
15 changed files with 254 additions and 297 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple
import torch
@@ -12,15 +12,13 @@ class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
rejection_sampling: bool
spec_token_ids: List[List[int]]
top_p: torch.Tensor
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
min_p: torch.Tensor
no_min_p: bool
# None when there are no speculated tokens.
spec_token_ids: Optional[List[List[int]]]
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]
generators: Dict[int, torch.Generator]
@@ -34,7 +32,8 @@ class SamplingMetadata:
repetition_penalties: torch.Tensor
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]
# req_index -> (min_tokens, stop_token_ids)
min_tokens: Dict[int, Tuple[int, Set[int]]]
logit_bias: List[Optional[Dict[int, float]]]