[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-18 12:15:33 -08:00
committed by GitHub
parent a4d577b379
commit 30172b4947
15 changed files with 254 additions and 297 deletions

View File

@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
 def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
     batch_size = len(spec_tokens)
     return SamplingMetadata(
-        temperature=0.0,
+        temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=True,
         spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
-        no_top_p=False,
-        no_top_k=False,
         min_p=torch.empty(batch_size, ),
-        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         presence_penalties=torch.tensor([]),
         repetition_penalties=torch.tensor([]),
         output_token_ids=[],
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )