[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
|
||||
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
batch_size = len(spec_tokens)
|
||||
return SamplingMetadata(
|
||||
temperature=0.0,
|
||||
temperature=torch.tensor([]),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=True,
|
||||
spec_token_ids=spec_tokens,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
no_top_p=False,
|
||||
no_top_k=False,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
no_min_p=True,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
min_tokens=[],
|
||||
stop_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user