[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
|
||||
no_top_p=True,
|
||||
no_top_k=True,
|
||||
generators={},
|
||||
max_num_logprobs=VOCAB_SIZE,
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
||||
sampling_metadata.min_tokens = min_tokens
|
||||
sampling_metadata.stop_token_ids = stop_token_ids
|
||||
sampler = Sampler()
|
||||
sampler_output = sampler(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
for vocab in range(VOCAB_SIZE):
|
||||
# Verify that the logprobs for stop token ids is set
|
||||
# to -inf.
|
||||
logprob_index = torch.where(
|
||||
sampler_output.logprob_token_ids[batch_idx] ==
|
||||
vocab)[0].item()
|
||||
if vocab in stop_token_ids[batch_idx]:
|
||||
assert sampler_output.logprobs[batch_idx][
|
||||
logprob_index] == -float("inf")
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if token_id in stop_token_ids[batch_idx]:
|
||||
assert logits[batch_idx][token_id] == -float("inf")
|
||||
else:
|
||||
assert sampler_output.logprobs[batch_idx][
|
||||
logprob_index] != -float("inf")
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
batch_size, presence_penalty, torch.device(device))
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
sampler_output = sampler(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
# The logprobs in the SamplerOutput are arranged in descending order.
|
||||
# Since all tokens initially have the same logprobs, the non-penalized
|
||||
# tokens will appear at the beginning, while the penalized tokens
|
||||
# will appear at the end of the list.
|
||||
penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
|
||||
VOCAB_SIZE - 1]
|
||||
penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
|
||||
non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
|
||||
non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
|
||||
assert non_penalized_log_prod > penalized_log_prod
|
||||
# Since all tokens initially have the same logits, the non-penalized
|
||||
# token ID will be the one with the highest logit value, while the
|
||||
# penalized token ID will be the one with the lowest logit value.
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
if presence_penalty > 0:
|
||||
# If `presence_penalty` is set to a value greater than 0, it
|
||||
# indicates a preference for new tokens over those already
|
||||
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
sampling_metadata.output_token_ids = output_token_ids
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
sampler_output = sampler(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
|
||||
non_penalized_token_id = logprobs_token_ids[0]
|
||||
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
distinct_sorted_token_ids_in_output = \
|
||||
sorted_token_ids_in_output[batch_idx]
|
||||
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
||||
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
batch_size, repetition_penalty, torch.device(device))
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
sampler_output = sampler(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
|
||||
non_penalized_token_id = logprobs_token_ids[0]
|
||||
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
prompt_tokens = sampling_metadata.prompt_token_ids[
|
||||
batch_idx][:].tolist()
|
||||
output_tokens = sampling_metadata.output_token_ids[batch_idx]
|
||||
|
||||
Reference in New Issue
Block a user