[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-18 12:15:33 -08:00
committed by GitHub
parent a4d577b379
commit 30172b4947
15 changed files with 254 additions and 297 deletions

View File

@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
rejection_sampling=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=[],
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
) -> Dict[int, Tuple[int, Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a dict that maps each batch index to a
(`min_tokens`, `stop_token_ids`) tuple: the `min_tokens` value and the
set of stop token IDs for that index.
If a batch index is included in `batch_indices_for_min_token_penalty`,
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens.append(
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens))
stop_token_ids.append(
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))
else:
min_tokens.append(np.random.randint(0, num_output_tokens))
stop_token_ids.append(set())
return (min_tokens, stop_token_ids)
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens
def _create_weighted_output_token_list(
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return (output_token_ids, sorted_token_ids_in_output)
return output_token_ids, sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id in stop_token_ids[batch_idx]:
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")