[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
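
A note on the shape this test now expects: top_p/top_k/min_p become optional tensors that are None when every request uses the default value, the separate stop_token_ids list folds into a per-index min_tokens mapping, and the no_top_p/no_top_k/no_min_p flags go away. A minimal sketch of that shape, inferred only from the test diff below; field types and ordering are assumptions, not the actual definition in vllm/v1/sample/metadata.py:

# Sketch only: field names inferred from the test changes in this commit;
# types and defaults are assumptions, not the real vLLM class.
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import torch


@dataclass
class SamplingMetadataSketch:
    temperature: torch.Tensor
    all_greedy: bool
    all_random: bool
    # None when every request is at the default (top_p == 1.0, top_k == 0,
    # min_p == 0.0), replacing the old no_top_p/no_top_k/no_min_p flags.
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]
    generators: Dict[int, torch.Generator]
    max_num_logprobs: int
    prompt_token_ids: torch.Tensor
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    output_token_ids: List[List[int]]
    spec_token_ids: Optional[List[List[int]]]
    # Replaces the separate min_tokens and stop_token_ids lists:
    # batch index -> (min_tokens, all_stop_token_ids).
    min_tokens: Dict[int, Tuple[int, Set[int]]]
    no_penalties: bool
    logit_bias: List[Optional[Dict[int, float]]]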
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
-    return (req_ids_to_remove, req_indices_to_remove_list)
+    return req_ids_to_remove, req_indices_to_remove_list
 
 
 def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
     top_p = [0.0 for _ in range(num_reqs)]
     min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
-    min_tokens = [0 for _ in range(num_reqs)]
+    min_tokens = {}
     logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
         top_p[index_in_input_batch] = req.sampling_params.top_p
         min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        stop_token_ids[
-            index_in_input_batch] = req.sampling_params.all_stop_token_ids
-        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        min_tokens[index_in_input_batch] = (
+            req.sampling_params.min_tokens,
+            req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
         all_greedy=False,
         all_random=True,
-        rejection_sampling=False,
-        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
-        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
-        no_top_p=all(x == 1.0 for x in top_p),
-        no_top_k=all(x == 0 for x in top_k),
-        min_p=torch.tensor(min_p, dtype=torch.float, device=device),
-        no_min_p=all(x == 0.0 for x in min_p),
+        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
+            top_p, dtype=torch.float, device=device),
+        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+            top_k, dtype=torch.int, device=device),
+        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
+            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         min_tokens=min_tokens,
-        stop_token_ids=stop_token_ids,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch.condense(req_indices_to_remove)
 
     # Generate the sampling metadata
-    sampling_metadata = input_batch.make_sampling_metadata(
-        req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
+    sampling_metadata = input_batch._make_sampling_metadata()
 
     # Create expected output.
     expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         input_batch.req_id_to_index,
         device=torch.device(device))
 
+    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
+        return (t1 is None
+                and t2 is None) or (t1 is not None and t2 is not None
+                                    and torch.allclose(t1, t2))
+
     # Assert the actual and expected output.
     assert torch.allclose(expected_sampling_metadata.temperature,
                           sampling_metadata.temperature)
-    assert torch.allclose(expected_sampling_metadata.top_p,
-                          sampling_metadata.top_p)
-    assert torch.allclose(expected_sampling_metadata.top_k,
-                          sampling_metadata.top_k)
+    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
+    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
     assert torch.allclose(
         expected_sampling_metadata.frequency_penalties,
         sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
     assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
-    assert expected_sampling_metadata.stop_token_ids == \
-        sampling_metadata.stop_token_ids
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
-    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
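
One consequence the updated test encodes: top_p, top_k, and min_p are None whenever no request in the batch overrides the default, so a consumer can skip that filtering pass entirely. A hypothetical illustration of the pattern (maybe_apply_top_p is an assumed helper for this sketch, not vLLM's sampler code):

from typing import Optional

import torch


def maybe_apply_top_p(logits: torch.Tensor,
                      top_p: Optional[torch.Tensor]) -> torch.Tensor:
    """Per-request nucleus (top-p) filtering; no-op when top_p is None."""
    if top_p is None:
        # No request set top_p < 1.0, so there is nothing to filter.
        return logits
    probs = logits.softmax(dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # Drop tokens once the cumulative probability before them exceeds top_p.
    drop = (cum_probs - sorted_probs) > top_p.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(drop, 0.0)
    filtered = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    # Return to log space; zeroed entries become a large negative value.
    return torch.log(filtered.clamp_min(1e-10))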