[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.empty(batch_size, ),
|
||||
top_k=torch.empty(batch_size, ),
|
||||
no_top_p=True,
|
||||
no_top_k=True,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
no_min_p=True,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
no_penalties=True,
|
||||
min_tokens=[],
|
||||
stop_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
|
||||
def _generate_min_token_penalties_and_stop_tokens(
|
||||
num_output_tokens: int, batch_size: int, vocab_size: int,
|
||||
batch_indices_for_min_token_penalty: List[int]
|
||||
) -> Tuple[List[int], List[Set[int]]]:
|
||||
) -> Dict[int, Tuple[int, Set[int]]]:
|
||||
"""
|
||||
Generates and returns a list of minimum token penalties (`min_tokens`)
|
||||
and a corresponding list of stop token IDs (`stop_token_ids`) for each
|
||||
Generates and returns a dict of minimum token penalties and
|
||||
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
|
||||
batch.
|
||||
|
||||
If a batch index is included in `batch_indices_for_min_token_penalty`,
|
||||
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
|
||||
and a random set of stop token IDs is created. Otherwise, a lower
|
||||
`min_tokens` value is assigned, and the stop token IDs set is empty.
|
||||
"""
|
||||
stop_token_ids: List[Set[int]] = []
|
||||
min_tokens: List[int] = []
|
||||
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
|
||||
for index in range(batch_size):
|
||||
if index in batch_indices_for_min_token_penalty:
|
||||
min_tokens.append(
|
||||
min_tokens[index] = (
|
||||
np.random.randint(num_output_tokens + 1,
|
||||
2 * num_output_tokens))
|
||||
stop_token_ids.append(
|
||||
2 * num_output_tokens),
|
||||
set(
|
||||
np.random.randint(0, vocab_size - 1)
|
||||
for _ in range(np.random.randint(0, vocab_size))))
|
||||
|
||||
else:
|
||||
min_tokens.append(np.random.randint(0, num_output_tokens))
|
||||
stop_token_ids.append(set())
|
||||
return (min_tokens, stop_token_ids)
|
||||
min_tokens[index] = (np.random.randint(0,
|
||||
num_output_tokens), set())
|
||||
return min_tokens
|
||||
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
|
||||
output_token_ids_for_batch.extend(
|
||||
[token_id for _ in range(index + 1)])
|
||||
output_token_ids.append(output_token_ids_for_batch)
|
||||
return (output_token_ids, sorted_token_ids_in_output)
|
||||
return output_token_ids, sorted_token_ids_in_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
batch_indices_for_min_token_penalty = np.random.randint(
|
||||
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
|
||||
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
|
||||
min_tokens = _generate_min_token_penalties_and_stop_tokens(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
|
||||
batch_indices_for_min_token_penalty)
|
||||
sampling_metadata.min_tokens = min_tokens
|
||||
sampling_metadata.stop_token_ids = stop_token_ids
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if token_id in stop_token_ids[batch_idx]:
|
||||
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
|
||||
if token_id in stop_token_ids:
|
||||
assert logits[batch_idx][token_id] == -float("inf")
|
||||
else:
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
|
||||
Reference in New Issue
Block a user