Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
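The changes below are almost entirely mechanical re-wrapping: yapf's paren-aligned continuation style is replaced by ruff's 4-space hanging indents with trailing commas, and multi-name imports are split one name per line. A minimal before/after sketch of the pattern, using names that appear in the affected tests:

    # yapf (before): arguments aligned to the opening parenthesis
    params = SamplingParams(max_tokens=max_tokens,
                            logprobs=num_lp,
                            prompt_logprobs=num_plp)

    # ruff format (after): hanging indent with a magic trailing comma
    params = SamplingParams(
        max_tokens=max_tokens,
        logprobs=num_lp,
        prompt_logprobs=num_plp,
    )
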
@@ -9,9 +9,12 @@ import pytest
 import torch
 
 from tests.v1.sample.utils import (
-    BatchLogprobsComposition, BatchLogprobsSpecType,
+    BatchLogprobsComposition,
+    BatchLogprobsSpecType,
     assert_incr_detok_str_matches_non_incr_detok_str,
-    compute_correct_cumulative_logprob, get_test_batch)
+    compute_correct_cumulative_logprob,
+    get_test_batch,
+)
 from vllm import SamplingParams
 from vllm.config import LogprobsMode
 
@@ -29,22 +32,23 @@ SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
 @pytest.fixture(
     scope="module",
     # Parameterize APC
-    params=[False, True])
+    params=[False, True],
+)
 def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
     with vllm_runner(
-            MODEL,
-            dtype=DTYPE,
-            max_logprobs=7,
-            # Very small number of batched tokens to ensure
-            # that we test chunking.
-            max_num_batched_tokens=16,
-            max_num_seqs=16,
-            max_model_len=128,
-            enforce_eager=True,
-            #TODO: enable this once we support it for
-            # prompt logprobs.
-            enable_prefix_caching=request.param,
-            gpu_memory_utilization=0.4,  # up to 2 alive concurrently
+        MODEL,
+        dtype=DTYPE,
+        max_logprobs=7,
+        # Very small number of batched tokens to ensure
+        # that we test chunking.
+        max_num_batched_tokens=16,
+        max_num_seqs=16,
+        max_model_len=128,
+        enforce_eager=True,
+        # TODO: enable this once we support it for
+        # prompt logprobs.
+        enable_prefix_caching=request.param,
+        gpu_memory_utilization=0.4,  # up to 2 alive concurrently
     ) as vllm_model:
         yield vllm_model
 
@@ -96,8 +100,8 @@ def _repeat_logprob_config(
     num_test_prompts = len(test_prompts)
     # Make sure there is a logprobs configuration for each test prompt
     logprob_prompt_logprob_list = list(
-        itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
-                         num_test_prompts))
+        itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
+    )
     # Now the number of prompts should match the number of sample params combos
     assert num_test_prompts == len(logprob_prompt_logprob_list)
     return logprob_prompt_logprob_list
@@ -115,24 +119,28 @@ def _run_and_validate(
     do_apc: bool,
 ) -> None:
     vllm_results = vllm_model.llm.generate(
-        test_prompts, sampling_params=vllm_sampling_params)
+        test_prompts, sampling_params=vllm_sampling_params
+    )
 
     for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
-            vllm_results, hf_logprobs, hf_outputs,
-            logprob_prompt_logprob_list):
-
+        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
+    ):
         # Extract request-level (prompt)logprobs config
         num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
 
         # Test whether sampled token output is consistent between vLLM and HF
         # vLLM prompt+completion should match HF output
         if temperature == 0.0:
-            assert (vllm_result.prompt_token_ids +
-                    vllm_result.outputs[0].token_ids == hf_output[0])
+            assert (
+                vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
+                == hf_output[0]
+            )
         else:
             # Sampled tokens won't match if not greedy
-            assert (vllm_result.prompt_token_ids == hf_output[0]
-                    [:len(vllm_result.prompt_token_ids)])
+            assert (
+                vllm_result.prompt_token_ids
+                == hf_output[0][: len(vllm_result.prompt_token_ids)]
+            )
 
         # Validate sample logprobs
         if num_top_logprobs is not None:
@@ -141,8 +149,9 @@ def _run_and_validate(
             # correct
             assert vllm_result.outputs[0].logprobs is not None
             assert len(vllm_result.outputs[0].logprobs) == max_tokens
-            for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
-                                          vllm_result.outputs[0].token_ids):
+            for logprobs, token_id in zip(
+                vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
+            ):
                 assert logprobs is not None
 
                 # Confirm that the output token appears among the logprobs
@@ -159,23 +168,26 @@ def _run_and_validate(
                 if num_top_logprobs > 0:
                     # We should have an entry for each of the topk ranks
                     all_ranks = {lp.rank for lp in logprobs.values()}
-                    assert all(r in all_ranks
-                               for r in range(1, num_top_logprobs + 1))
+                    assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))
 
             output_text = vllm_result.outputs[0].text
             output_string_from_most_likely_tokens_lst: list[str] = []
             for top_logprobs in vllm_result.outputs[0].logprobs:
                 top_logprob = next(iter(top_logprobs.values()))
                 output_string_from_most_likely_tokens_lst.append(
-                    top_logprob.decoded_token)
+                    top_logprob.decoded_token
+                )
 
             output_string_from_most_likely_tokens = "".join(
-                output_string_from_most_likely_tokens_lst)
+                output_string_from_most_likely_tokens_lst
+            )
             assert_incr_detok_str_matches_non_incr_detok_str(
-                output_text, output_string_from_most_likely_tokens,
+                output_text,
+                output_string_from_most_likely_tokens,
                 "The output text from the top logprob for each token "
                 "position should be the same as the output text in the "
-                "result.")
+                "result.",
+            )
 
             # Compare vLLM sample logprobs to HF
             vllm_sample_logprobs = vllm_result.outputs[0].logprobs
@@ -187,11 +199,12 @@ def _run_and_validate(
|
||||
logprob,
|
||||
hf_logprob[i][-1][token_id].item(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
assert isinstance(
|
||||
sample_logprob.decoded_token,
|
||||
str), ("The token should be decoded by the time it is"
|
||||
" returned to the user.")
|
||||
rtol=1e-2,
|
||||
)
|
||||
assert isinstance(sample_logprob.decoded_token, str), (
|
||||
"The token should be decoded by the time it is"
|
||||
" returned to the user."
|
||||
)
|
||||
|
||||
# At this point we know the sample logprobs are correct for this
|
||||
# request. Validate that cumulative_logprob is actually the sum.
|
||||
@@ -201,7 +214,8 @@ def _run_and_validate(
|
||||
vllm_result.outputs[0].cumulative_logprob,
|
||||
compute_correct_cumulative_logprob(vllm_result.outputs[0]),
|
||||
atol=1e-6,
|
||||
rtol=1e-6)
|
||||
rtol=1e-6,
|
||||
)
|
||||
else:
|
||||
# Logprobs disabled for this request; should be None
|
||||
assert vllm_result.outputs[0].logprobs is None
|
||||
@@ -214,17 +228,17 @@ def _run_and_validate(
|
||||
assert vllm_result.prompt_logprobs[0] is None
|
||||
# - Prompt logprobs are returned for all indices in
|
||||
# the prompt
|
||||
assert len(vllm_result.prompt_logprobs) == len(
|
||||
vllm_result.prompt_token_ids)
|
||||
assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
|
||||
for prompt_logprobs, prompt_token_id in zip(
|
||||
vllm_result.prompt_logprobs[1:],
|
||||
vllm_result.prompt_token_ids[1:]):
|
||||
vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
|
||||
):
|
||||
assert prompt_logprobs is not None
|
||||
|
||||
# Confirm that the prompt token appears among the logprobs
|
||||
assert prompt_token_id in prompt_logprobs
|
||||
token_in_topk = prompt_logprobs[
|
||||
prompt_token_id].rank <= num_top_prompt_logprobs
|
||||
token_in_topk = (
|
||||
prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
|
||||
)
|
||||
|
||||
# If the prompt token is not included in the top K
|
||||
# logprob, it can return 1 more data
|
||||
@@ -236,8 +250,9 @@ def _run_and_validate(
|
||||
if num_top_prompt_logprobs > 0:
|
||||
# We should have an entry for each of the topk ranks
|
||||
all_ranks = {lp.rank for lp in prompt_logprobs.values()}
|
||||
assert all(r in all_ranks
|
||||
for r in range(1, num_top_prompt_logprobs + 1))
|
||||
assert all(
|
||||
r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
|
||||
)
|
||||
|
||||
# Compare prompt logprobs to HF
|
||||
# The first prompt logprob is always None, so we compare it from
|
||||
@@ -249,19 +264,24 @@ def _run_and_validate(
|
||||
logprob.logprob,
|
||||
hf_logprob[0][i][token_id].item(),
|
||||
atol=2e-2,
|
||||
rtol=2e-2)
|
||||
rtol=2e-2,
|
||||
)
|
||||
else:
|
||||
assert vllm_result.prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_logprobs_composition",
|
||||
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
|
||||
)
|
||||
@pytest.mark.parametrize("temperature", [0.0, 2.0])
|
||||
def test_get_logprobs_and_prompt_logprobs(
|
||||
hf_model, vllm_model,
|
||||
batch_logprobs_composition: BatchLogprobsComposition,
|
||||
temperature: float, example_prompts: list[str],
|
||||
monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
hf_model,
|
||||
vllm_model,
|
||||
batch_logprobs_composition: BatchLogprobsComposition,
|
||||
temperature: float,
|
||||
example_prompts: list[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test V1 Engine logprobs & prompt logprobs
|
||||
|
||||
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
|
||||
@@ -291,8 +311,9 @@ def test_get_logprobs_and_prompt_logprobs(
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
|
||||
if do_apc and (temperature < 2.0
|
||||
or batch_logprobs_composition != SAMPLE_PROMPT):
|
||||
if do_apc and (
|
||||
temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT
|
||||
):
|
||||
# Skip some test-cases to save time.
|
||||
pytest.skip()
|
||||
test_prompts = example_prompts
|
||||
@@ -309,19 +330,21 @@ def test_get_logprobs_and_prompt_logprobs(
|
||||
|
||||
# Batch has mixed sample params
|
||||
# (different logprobs/prompt logprobs combos)
|
||||
logprob_prompt_logprob_list = get_test_batch(
|
||||
batch_logprobs_composition)
|
||||
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
|
||||
|
||||
# Ensure that each test prompt has a logprob config for testing
|
||||
logprob_prompt_logprob_list = _repeat_logprob_config(
|
||||
test_prompts, logprob_prompt_logprob_list)
|
||||
test_prompts, logprob_prompt_logprob_list
|
||||
)
|
||||
# Generate SamplingParams
|
||||
vllm_sampling_params = [
|
||||
SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=num_lp,
|
||||
prompt_logprobs=num_plp,
|
||||
temperature=temperature,
|
||||
seed=1984)
|
||||
SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_lp,
|
||||
prompt_logprobs=num_plp,
|
||||
temperature=temperature,
|
||||
seed=1984,
|
||||
)
|
||||
for num_lp, num_plp in logprob_prompt_logprob_list
|
||||
]
|
||||
for _ in range(2 if do_apc else 1):
|
||||
@@ -334,7 +357,8 @@ def test_get_logprobs_and_prompt_logprobs(
|
||||
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
do_apc=do_apc)
|
||||
do_apc=do_apc,
|
||||
)
|
||||
|
||||
|
||||
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -351,19 +375,18 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256)
|
||||
max_model_len=256,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"],
|
||||
sampling_params=bad_sampling_params)
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
|
||||
|
||||
def test_none_logprobs(vllm_model, example_prompts,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
|
||||
|
||||
Args:
|
||||
@@ -388,14 +411,12 @@ def test_none_logprobs(vllm_model, example_prompts,
|
||||
for i in range(len(results_logprobs_none)):
|
||||
# Check sample logprobs are None
|
||||
assert results_logprobs_none[i].outputs[0].logprobs is None
|
||||
assert results_logprobs_none[i].outputs[
|
||||
0].cumulative_logprob is None
|
||||
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
|
||||
# Check prompt logprobs are None
|
||||
assert results_logprobs_none[i].prompt_logprobs is None
|
||||
|
||||
|
||||
def test_zero_logprobs(vllm_model, example_prompts,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Engine should return sampled token and prompt token logprobs
|
||||
|
||||
Args:
|
||||
@@ -406,12 +427,12 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
max_tokens = 5
|
||||
|
||||
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=0,
|
||||
prompt_logprobs=0,
|
||||
temperature=0.0)
|
||||
sampling_params_logprobs_zero = SamplingParams(
|
||||
max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
|
||||
)
|
||||
results_logprobs_zero = vllm_model.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_zero)
|
||||
example_prompts, sampling_params=sampling_params_logprobs_zero
|
||||
)
|
||||
|
||||
for i in range(len(results_logprobs_zero)):
|
||||
# Check that there is one sample logprob dict for each
|
||||
@@ -422,8 +443,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
|
||||
assert logprobs is not None
|
||||
assert len(sampled_token_ids) == len(logprobs)
|
||||
assert results_logprobs_zero[i].outputs[
|
||||
0].cumulative_logprob is not None
|
||||
assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
|
||||
# Check that there is one prompt logprob dict for each
|
||||
# prompt token
|
||||
assert prompt_logprobs is not None
|
||||
@@ -444,13 +464,15 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256)
|
||||
max_model_len=256,
|
||||
)
|
||||
|
||||
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
|
||||
logprobs=-1,
|
||||
prompt_logprobs=-1)
|
||||
sampling_params_logprobs_all = SamplingParams(
|
||||
max_tokens=5, logprobs=-1, prompt_logprobs=-1
|
||||
)
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all)
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all
|
||||
)
|
||||
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
|
||||
|
||||
for i in range(len(results_logprobs_all)):
|
||||
@@ -466,13 +488,13 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
|
||||
def test_logprobs_mode(logprobs_mode: LogprobsMode,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test with LLM engine with different logprobs_mode.
|
||||
For logprobs, we should have non-positive values.
|
||||
For logits, we should expect at least one positive values.
|
||||
"""
|
||||
from vllm import LLM
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -483,10 +505,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.05,
|
||||
max_model_len=16,
|
||||
logprobs_mode=logprobs_mode)
|
||||
logprobs_mode=logprobs_mode,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
results = llm.generate(["Hello world"],
|
||||
sampling_params=vllm_sampling_params)
|
||||
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
total_token_with_logprobs = 0
|
||||
positive_values = 0
|
||||
|
||||
@@ -15,22 +15,23 @@ EXPECTED_VALUE = 0.62
|
||||
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
|
||||
SERVER_ARGS = [
|
||||
"--enforce_eager", "--no_enable_prefix_caching",
|
||||
"--gpu-memory-utilization=0.8"
|
||||
"--enforce_eager",
|
||||
"--no_enable_prefix_caching",
|
||||
"--gpu-memory-utilization=0.8",
|
||||
]
|
||||
NUM_CONCURRENT = 100
|
||||
|
||||
|
||||
def test_prompt_logprobs_e2e():
|
||||
results = lm_eval.simple_evaluate(model="vllm",
|
||||
model_args=MODEL_ARGS,
|
||||
tasks=TASK,
|
||||
batch_size="auto")
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm", model_args=MODEL_ARGS, tasks=TASK, batch_size="auto"
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
def test_prompt_logprobs_e2e_server():
|
||||
@@ -40,7 +41,8 @@ def test_prompt_logprobs_e2e_server():
|
||||
model_args = (
|
||||
f"model={MODEL},"
|
||||
f"base_url={url},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||
)
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
@@ -49,6 +51,7 @@ def test_prompt_logprobs_e2e_server():
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
@@ -9,8 +9,7 @@ import torch.nn.functional as F
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
|
||||
RejectionSampler)
|
||||
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
@@ -21,10 +20,11 @@ def rejection_sampler():
|
||||
return RejectionSampler()
|
||||
|
||||
|
||||
def create_logits_tensor(output_token_ids: list[list[int]],
|
||||
vocab_size: int = 100) -> torch.Tensor:
|
||||
def create_logits_tensor(
|
||||
output_token_ids: list[list[int]], vocab_size: int = 100
|
||||
) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
will produce desired token ids on argmax"""
|
||||
token_ids = [tokens[:-1] for tokens in output_token_ids]
|
||||
num_total_tokens = sum(len(tokens) for tokens in token_ids)
|
||||
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
|
||||
@@ -44,8 +44,8 @@ def create_sampling_metadata(
|
||||
generators: Optional[dict[int, Any]] = None,
|
||||
) -> SamplingMetadata:
|
||||
"""Create a v1 sampling metadata object with all_greedy set
|
||||
to the given value. Either all greedy or all random sampling
|
||||
is used.
|
||||
to the given value. Either all greedy or all random sampling
|
||||
is used.
|
||||
"""
|
||||
generators = generators or {}
|
||||
if all_greedy:
|
||||
@@ -81,10 +81,10 @@ def test_perfect_match(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -93,9 +93,7 @@ def test_perfect_match(rejection_sampler):
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 3, 4]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
@@ -106,10 +104,10 @@ def test_early_mismatch(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -129,15 +127,16 @@ def test_early_mismatch(rejection_sampler):
|
||||
def test_multiple_sequences(rejection_sampler):
|
||||
"""Test handling multiple sequences of speculated tokens"""
|
||||
spec_tokens = [[1, 2], [3]]
|
||||
output_tokens = [[1, 2, 5], [3,
|
||||
4]] # Two sequences with bonus tokens 5 and 4
|
||||
output_tokens = [[1, 2, 5], [3, 4]] # Two sequences with bonus tokens 5 and 4
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -146,9 +145,9 @@ def test_multiple_sequences(rejection_sampler):
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
@@ -159,10 +158,10 @@ def test_single_token_sequence(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -182,10 +181,10 @@ def test_empty_sequence(rejection_sampler):
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -201,15 +200,16 @@ def test_empty_sequence(rejection_sampler):
|
||||
def test_multiple_mismatches(rejection_sampler):
|
||||
"""Test handling multiple sequences with mismatches"""
|
||||
spec_tokens = [[1, 2, 3], [4, 5, 6]]
|
||||
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
|
||||
9]] # Mismatches in both sequences
|
||||
output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]] # Mismatches in both sequences
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -219,8 +219,10 @@ def test_multiple_mismatches(rejection_sampler):
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
|
||||
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
|
||||
[
|
||||
[1, 2, 7, PLACEHOLDER_TOKEN_ID],
|
||||
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
|
||||
],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
@@ -232,18 +234,23 @@ def test_multiple_mismatches(rejection_sampler):
|
||||
[
|
||||
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
|
||||
([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
|
||||
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
|
||||
])
|
||||
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
||||
expected):
|
||||
(
|
||||
[[1, 2], [3, 4]],
|
||||
[[1, 5, 6], [3, 4, 7]],
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
|
||||
), # Mixed matches
|
||||
],
|
||||
)
|
||||
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
|
||||
"""Parametrized test for various matching scenarios"""
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[tokens[-1] for tokens in output_tokens], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
@@ -252,9 +259,7 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected_tensor = torch.tensor(expected,
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected_tensor)
|
||||
|
||||
|
||||
@@ -273,22 +278,15 @@ def test_deterministic_when_seeded(
|
||||
n_rep: int,
|
||||
):
|
||||
num_tokens = batch_size * k
|
||||
draft_probs = torch.rand(num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
target_logits = torch.rand_like(draft_probs)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
bonus_token_ids = torch.randint(
|
||||
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
|
||||
)
|
||||
draft_token_ids = torch.randint(
|
||||
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
|
||||
)
|
||||
|
||||
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
|
||||
|
||||
@@ -296,17 +294,17 @@ def test_deterministic_when_seeded(
|
||||
for _ in range(n_rep):
|
||||
seeded_seqs = {
|
||||
i: torch.Generator(device=DEVICE).manual_seed(i)
|
||||
for i in range(batch_size) if seeded_mask[i]
|
||||
for i in range(batch_size)
|
||||
if seeded_mask[i]
|
||||
}
|
||||
|
||||
temperature = torch.ones(batch_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature,
|
||||
generators=seeded_seqs)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature, generators=seeded_seqs
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=DEVICE)
|
||||
draft_token_ids.tolist(), device=DEVICE
|
||||
)
|
||||
rep_result = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
@@ -352,8 +350,7 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
num_reference_probs = 100
|
||||
|
||||
# Prepare draft, target, and reference probability distributions
|
||||
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
|
||||
dim=-1)
|
||||
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
|
||||
target_logits = torch.rand(vocab_size, dtype=torch.float32)
|
||||
target_probs = F.softmax(target_logits, dim=-1)
|
||||
reference_probs = F.softmax(
|
||||
@@ -368,38 +365,48 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
for num_samples in sample_sizes:
|
||||
# Sample using rejection sampling.
|
||||
rej_sample_probs = estimate_rejection_sampling_pdf(
|
||||
draft_probs, target_logits, k, vocab_size, num_samples)
|
||||
draft_probs, target_logits, k, vocab_size, num_samples
|
||||
)
|
||||
rej_sample_probs = rej_sample_probs.to(DEVICE)
|
||||
|
||||
# Average distance from reference probs.
|
||||
reference_vs_rejsample_dist = torch.dist(
|
||||
reference_probs,
|
||||
rej_sample_probs).item() / reference_probs.shape[0]
|
||||
target_vs_rejsample_dist = torch.dist(target_probs,
|
||||
rej_sample_probs).item()
|
||||
reference_vs_rejsample_dist = (
|
||||
torch.dist(reference_probs, rej_sample_probs).item()
|
||||
/ reference_probs.shape[0]
|
||||
)
|
||||
target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item()
|
||||
|
||||
distance_wrt_reference.append(reference_vs_rejsample_dist)
|
||||
distance_wrt_target.append(target_vs_rejsample_dist)
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
distance_wrt_target
|
||||
)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
distance_wrt_reference
|
||||
)
|
||||
|
||||
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
|
||||
f"{reference_vs_rejsample_dist=:.05f}")
|
||||
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
|
||||
f"{relative_change_in_distance_wrt_reference=:.02f}")
|
||||
print(
|
||||
f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
|
||||
f"{reference_vs_rejsample_dist=:.05f}"
|
||||
)
|
||||
print(
|
||||
f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
|
||||
f"{relative_change_in_distance_wrt_reference=:.02f}"
|
||||
)
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
distance_wrt_target
|
||||
)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
distance_wrt_reference
|
||||
)
|
||||
|
||||
expected_improvement_multiplier = 20
|
||||
assert (relative_change_in_distance_wrt_target
|
||||
> relative_change_in_distance_wrt_reference *
|
||||
expected_improvement_multiplier)
|
||||
assert (
|
||||
relative_change_in_distance_wrt_target
|
||||
> relative_change_in_distance_wrt_reference * expected_improvement_multiplier
|
||||
)
|
||||
|
||||
|
||||
def get_ratio_first_to_last(elements: list[float]) -> float:
|
||||
@@ -427,28 +434,29 @@ def estimate_rejection_sampling_pdf(
|
||||
rejection_sampler = RejectionSampler()
|
||||
num_tokens = num_samples * k
|
||||
# Repeat draft probs num_samples * k times.
|
||||
draft_probs = draft_probs.reshape(1, 1,
|
||||
vocab_size).repeat(num_samples, k, 1)
|
||||
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
|
||||
|
||||
# Repeat target probs num_tokens times.
|
||||
target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs.
|
||||
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
|
||||
num_samples=k,
|
||||
replacement=True).reshape(
|
||||
num_samples, k)
|
||||
draft_token_ids = torch.multinomial(
|
||||
draft_probs[:, 0, :], num_samples=k, replacement=True
|
||||
).reshape(num_samples, k)
|
||||
draft_probs = draft_probs.view(num_tokens, vocab_size)
|
||||
|
||||
# Bonus tokens not used but required.
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
|
||||
device=DEVICE).repeat(num_samples, 1)
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
|
||||
num_samples, 1
|
||||
)
|
||||
|
||||
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=bonus_token_ids.device)
|
||||
draft_token_ids.tolist(), device=bonus_token_ids.device
|
||||
)
|
||||
output_token_ids = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
@@ -458,11 +466,12 @@ def estimate_rejection_sampling_pdf(
|
||||
)
|
||||
output_token_ids = output_token_ids[:, :-1].flatten()
|
||||
|
||||
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
|
||||
device="cpu"),
|
||||
bins=vocab_size,
|
||||
range=(0, vocab_size),
|
||||
density=True)
|
||||
hist = torch.histogram(
|
||||
output_token_ids.to(dtype=torch.float, device="cpu"),
|
||||
bins=vocab_size,
|
||||
range=(0, vocab_size),
|
||||
density=True,
|
||||
)
|
||||
|
||||
return hist.hist
|
||||
|
||||
@@ -480,9 +489,9 @@ def _test_masked_logits(
|
||||
num_tokens = batch_size * num_draft_tokens
|
||||
|
||||
# Create random draft probabilities.
|
||||
draft_probs = torch.rand((num_tokens, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
draft_probs = torch.rand(
|
||||
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs
|
||||
@@ -491,9 +500,7 @@ def _test_masked_logits(
|
||||
draft_token_ids = draft_token_ids.tolist()
|
||||
|
||||
# Bonus tokens not used but required
|
||||
bonus_token_ids = torch.zeros((batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
|
||||
|
||||
# Create spec decode metadata
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
@@ -531,8 +538,7 @@ def test_top_k(rejection_sampler, top_k):
|
||||
|
||||
# Randomly create top-k indices.
|
||||
top_k_indices = [
|
||||
torch.randperm(vocab_size, device=DEVICE)[:top_k]
|
||||
for _ in range(num_tokens)
|
||||
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)
|
||||
]
|
||||
top_k_indices = torch.stack(top_k_indices)
|
||||
|
||||
@@ -550,9 +556,7 @@ def test_top_k(rejection_sampler, top_k):
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_k=torch.tensor([top_k] * batch_size,
|
||||
device=DEVICE,
|
||||
dtype=torch.int64),
|
||||
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
@@ -595,9 +599,7 @@ def test_top_p(rejection_sampler, top_p):
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_p=torch.tensor([top_p] * batch_size,
|
||||
device=DEVICE,
|
||||
dtype=torch.float32),
|
||||
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
|
||||
@@ -29,12 +29,12 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
|
||||
return fake_logits
|
||||
|
||||
|
||||
def _create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
return torch.full((batch_size, ),
|
||||
fill_value=penalty_value,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
def _create_penalty_tensor(
|
||||
batch_size: int, penalty_value: float, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
return torch.full(
|
||||
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
|
||||
def _create_prompt_tokens_tensor(
|
||||
@@ -62,9 +62,9 @@ def _create_allowed_token_ids(
|
||||
if i % 2 == 1:
|
||||
continue
|
||||
if mask is None:
|
||||
mask = torch.zeros((batch_size, vocab_size),
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
mask = torch.zeros(
|
||||
(batch_size, vocab_size), dtype=torch.bool, device=device
|
||||
)
|
||||
start = min(i, vocab_size - 1)
|
||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||
mask[i, start:end] = True
|
||||
@@ -80,9 +80,9 @@ def _create_bad_words_token_ids(
|
||||
for batch_idx in range(batch_size):
|
||||
token_ids_single_batch = []
|
||||
for bad_words_length in bad_words_lengths:
|
||||
token_ids = np.random.choice(vocab_size,
|
||||
size=bad_words_length,
|
||||
replace=True).tolist()
|
||||
token_ids = np.random.choice(
|
||||
vocab_size, size=bad_words_length, replace=True
|
||||
).tolist()
|
||||
token_ids_single_batch.append(token_ids)
|
||||
bad_words_token_ids[batch_idx] = token_ids_single_batch
|
||||
if batch_size >= 2:
|
||||
@@ -95,26 +95,27 @@ def _create_bad_words_token_ids(
|
||||
# Returns all last tokens of bad word sequences that share the same prefix
|
||||
# as `given_prefix` (excluding the last token).
|
||||
def _collect_suffixes_with_same_prefix(
|
||||
given_prefix: list[int],
|
||||
bad_words_token_ids: list[list[int]]) -> list[int]:
|
||||
given_prefix: list[int], bad_words_token_ids: list[list[int]]
|
||||
) -> list[int]:
|
||||
return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
|
||||
|
||||
|
||||
# generate a valid token id that is not in bad_words_token_ids
|
||||
def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
|
||||
vocab_size: int) -> int:
|
||||
def _generate_valid_token_id(
|
||||
bad_words_token_ids: list[list[int]], vocab_size: int
|
||||
) -> int:
|
||||
forbidden_start_tokens = set()
|
||||
for bad_word in bad_words_token_ids:
|
||||
forbidden_start_tokens.add(bad_word[0])
|
||||
# Get a safe token that's not in forbidden starts
|
||||
safe_token_candidates = list(
|
||||
set(range(vocab_size)) - forbidden_start_tokens)
|
||||
safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens)
|
||||
# Pick a random safe token
|
||||
return np.random.choice(safe_token_candidates)
|
||||
|
||||
|
||||
def _update_output_token_ids_for_bad_words(
|
||||
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
|
||||
metadata: SamplingMetadata, vocab_size: int
|
||||
) -> dict[int, list[int]]:
|
||||
bad_words_last_tokens = {}
|
||||
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
|
||||
output_token_ids = metadata.output_token_ids[batch_idx]
|
||||
@@ -132,12 +133,13 @@ def _update_output_token_ids_for_bad_words(
|
||||
# Collect all last tokens from other bad words
|
||||
# that share this prefix
|
||||
bad_words_last_token.extend(
|
||||
_collect_suffixes_with_same_prefix(
|
||||
prefix, bad_words_token_ids))
|
||||
_collect_suffixes_with_same_prefix(prefix, bad_words_token_ids)
|
||||
)
|
||||
break # Maximum one update to output_token_ids
|
||||
else: # Make sure no accidental match to bad words
|
||||
output_token_ids[-1] = _generate_valid_token_id(
|
||||
bad_words_token_ids, vocab_size)
|
||||
bad_words_token_ids, vocab_size
|
||||
)
|
||||
bad_words_last_tokens[batch_idx] = bad_words_last_token
|
||||
return bad_words_last_tokens
|
||||
|
||||
@@ -152,22 +154,24 @@ def _create_default_sampling_metadata(
|
||||
prompt_token_ids: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
output_token_ids.append(
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
|
||||
)
|
||||
prompt_token_ids.append(
|
||||
np.random.randint(0,
|
||||
vocab_size,
|
||||
size=np.random.randint(
|
||||
1, MAX_NUM_PROMPT_TOKENS)).tolist())
|
||||
np.random.randint(
|
||||
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
|
||||
).tolist()
|
||||
)
|
||||
fake_sampling_metadata = SamplingMetadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
temperature=torch.full((batch_size,), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(
|
||||
prompt_token_ids, vocab_size, device
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
@@ -181,8 +185,8 @@ def _create_default_sampling_metadata(
|
||||
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
batch_size: int,
|
||||
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
|
||||
batch_size: int, vocab_size: int
|
||||
) -> tuple[list[list[int]], list[list[int]]]:
|
||||
"""
|
||||
Creates an output token list where each token occurs a distinct
|
||||
number of times.
|
||||
@@ -203,14 +207,13 @@ def _create_weighted_output_token_list(
|
||||
output_token_ids: list[list[int]] = []
|
||||
sorted_token_ids_in_output: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
distinct_token_ids = np.random.choice(vocab_size,
|
||||
size=np.random.randint(1, 10),
|
||||
replace=False).tolist()
|
||||
distinct_token_ids = np.random.choice(
|
||||
vocab_size, size=np.random.randint(1, 10), replace=False
|
||||
).tolist()
|
||||
sorted_token_ids_in_output.append(distinct_token_ids)
|
||||
output_token_ids_for_batch = []
|
||||
for index, token_id in enumerate(distinct_token_ids):
|
||||
output_token_ids_for_batch.extend(
|
||||
[token_id for _ in range(index + 1)])
|
||||
output_token_ids_for_batch.extend([token_id for _ in range(index + 1)])
|
||||
output_token_ids.append(output_token_ids_for_batch)
|
||||
return output_token_ids, sorted_token_ids_in_output
|
||||
|
||||
@@ -218,8 +221,9 @@ def _create_weighted_output_token_list(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
|
||||
def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
presence_penalty: float):
|
||||
def test_sampler_presence_penalty(
|
||||
device: str, batch_size: int, presence_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that if presence penalty is enabled then tokens
|
||||
are penalized as per their presence in the existing output.
|
||||
@@ -229,10 +233,12 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
sampling_metadata.presence_penalties = _create_penalty_tensor(
|
||||
batch_size, presence_penalty, torch.device(device))
|
||||
batch_size, presence_penalty, torch.device(device)
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
@@ -263,8 +269,9 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
|
||||
def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
frequency_penalty: float):
|
||||
def test_sampler_frequency_penalty(
|
||||
device: str, batch_size: int, frequency_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that if frequency penalty is enabled then tokens are
|
||||
penalized as per their frequency of occurrence.
|
||||
@@ -274,14 +281,15 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.frequency_penalties = _create_penalty_tensor(
|
||||
batch_size, frequency_penalty, torch.device(device))
|
||||
output_token_ids, sorted_token_ids_in_output = \
|
||||
_create_weighted_output_token_list(
|
||||
batch_size,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
batch_size, frequency_penalty, torch.device(device)
|
||||
)
|
||||
output_token_ids, sorted_token_ids_in_output = _create_weighted_output_token_list(
|
||||
batch_size,
|
||||
VOCAB_SIZE,
|
||||
)
|
||||
sampling_metadata.output_token_ids = output_token_ids
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
@@ -290,18 +298,17 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
|
||||
batch_idx]
|
||||
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[batch_idx]
|
||||
most_frequent_token_id = distinct_sorted_token_ids_in_output[
|
||||
len(distinct_sorted_token_ids_in_output) - 1]
|
||||
len(distinct_sorted_token_ids_in_output) - 1
|
||||
]
|
||||
if frequency_penalty > 0:
|
||||
# If `frequency_penalty` is set to > 0, it indicates
|
||||
# a preference for new tokens over existing ones. Verify that the
|
||||
# non-penalized token ID is not present in the output, while the
|
||||
# most penalized token is the one that occurs most frequently in
|
||||
# the output.
|
||||
assert (non_penalized_token_id
|
||||
not in distinct_sorted_token_ids_in_output)
|
||||
assert non_penalized_token_id not in distinct_sorted_token_ids_in_output
|
||||
assert penalized_token_id == most_frequent_token_id
|
||||
elif frequency_penalty < 0:
|
||||
# If `frequency_penalty` is set to < 0, it indicates
|
||||
@@ -316,8 +323,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
|
||||
def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
repetition_penalty: float):
|
||||
def test_sampler_repetition_penalty(
|
||||
device: str, batch_size: int, repetition_penalty: float
|
||||
):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
@@ -328,9 +336,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.repetition_penalties = _create_penalty_tensor(
|
||||
batch_size, repetition_penalty, torch.device(device))
|
||||
batch_size, repetition_penalty, torch.device(device)
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
@@ -338,32 +348,40 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
penalized_token_id = logits[batch_idx].argmin().item()
|
||||
prompt_tokens = sampling_metadata.prompt_token_ids[
|
||||
batch_idx][:].tolist()
|
||||
prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx][:].tolist()
|
||||
output_tokens = sampling_metadata.output_token_ids[batch_idx]
|
||||
if repetition_penalty > 1.0:
|
||||
# If `repetition_penalty` > 1.0, verify that the non-penalized
|
||||
# token ID has not been seen before, while the penalized token ID
|
||||
# exists either in the prompt or the output.
|
||||
assert (non_penalized_token_id not in prompt_tokens
|
||||
and non_penalized_token_id not in output_tokens)
|
||||
assert (penalized_token_id in prompt_tokens
|
||||
or penalized_token_id in output_tokens)
|
||||
assert (
|
||||
non_penalized_token_id not in prompt_tokens
|
||||
and non_penalized_token_id not in output_tokens
|
||||
)
|
||||
assert (
|
||||
penalized_token_id in prompt_tokens
|
||||
or penalized_token_id in output_tokens
|
||||
)
|
||||
elif repetition_penalty < 1.0:
|
||||
# If `repetition_penalty` < 1.0, verify that the penalized
|
||||
# token ID has not been seen before, while the non-penalized
|
||||
# token ID exists either in the prompt or the output.
|
||||
assert (penalized_token_id not in prompt_tokens
|
||||
and penalized_token_id not in output_tokens)
|
||||
assert (non_penalized_token_id in prompt_tokens
|
||||
or non_penalized_token_id in output_tokens)
|
||||
assert (
|
||||
penalized_token_id not in prompt_tokens
|
||||
and penalized_token_id not in output_tokens
|
||||
)
|
||||
assert (
|
||||
non_penalized_token_id in prompt_tokens
|
||||
or non_penalized_token_id in output_tokens
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
|
||||
def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
num_allowed_token_ids: int):
|
||||
def test_sampler_allowed_token_ids(
|
||||
device: str, batch_size: int, num_allowed_token_ids: int
|
||||
):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
@@ -374,7 +392,8 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
mask = _create_allowed_token_ids(
|
||||
batch_size=batch_size,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
@@ -394,17 +413,19 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
start = min(batch_idx, VOCAB_SIZE - 1)
|
||||
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
|
||||
if token_id >= start and token_id < end:
|
||||
assert logits_for_req[token_id] == -float(
|
||||
"inf"), f"{batch_idx}, {token_id}"
|
||||
assert logits_for_req[token_id] == -float("inf"), (
|
||||
f"{batch_idx}, {token_id}"
|
||||
)
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(device: str, batch_size: int,
|
||||
bad_words_lengths: tuple[int, ...]):
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(
|
||||
device: str, batch_size: int, bad_words_lengths: tuple[int, ...]
|
||||
):
|
||||
"""
|
||||
Test to verify that when the bad words restriction is present, tokens
|
||||
are penalized based on their match with the bad words.
|
||||
@@ -414,19 +435,24 @@ def test_sampler_bad_words(device: str, batch_size: int,
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
|
||||
batch_size, VOCAB_SIZE, bad_words_lengths)
|
||||
batch_size, VOCAB_SIZE, bad_words_lengths
|
||||
)
|
||||
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
|
||||
sampling_metadata, VOCAB_SIZE)
|
||||
sampling_metadata, VOCAB_SIZE
|
||||
)
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if (batch_idx in bad_words_last_tokens
|
||||
and token_id in bad_words_last_tokens[batch_idx]):
|
||||
if (
|
||||
batch_idx in bad_words_last_tokens
|
||||
and token_id in bad_words_last_tokens[batch_idx]
|
||||
):
|
||||
assert logits_for_req[token_id] == -float("inf")
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
@@ -66,9 +66,9 @@ def test_stop(llm):
|
||||
# Output should not contain the stop word.
|
||||
assert len(new_split_text) == STOP_IDX
|
||||
|
||||
params = SamplingParams(temperature=0,
|
||||
stop=split_text[STOP_IDX],
|
||||
include_stop_str_in_output=True)
|
||||
params = SamplingParams(
|
||||
temperature=0, stop=split_text[STOP_IDX], include_stop_str_in_output=True
|
||||
)
|
||||
output = llm.generate(PROMPT, params)
|
||||
new_split_text = output[0].outputs[0].text.split()
|
||||
|
@@ -103,8 +103,8 @@ def test_detokenize_false(llm):
assert len(output[0].outputs[0].text) == 0

output = llm.generate(
PROMPT, SamplingParams(detokenize=False, logprobs=3,
prompt_logprobs=3))
PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)
)
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0

@@ -131,8 +131,7 @@ def test_bad_words(llm):
assert bad_words_1 not in new_text

bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0,
bad_words=[bad_words_1, bad_words_2])
params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
@@ -158,8 +157,7 @@ def test_allowed_token_ids(llm):

TOKEN_ID = 10
allowed_token_ids = [TOKEN_ID]
output = llm.generate(PROMPT,
SamplingParams(allowed_token_ids=allowed_token_ids))
output = llm.generate(PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

# Reject empty allowed_token_ids.

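These LLM-level tests all drive behaviour through SamplingParams fields (stop with include_stop_str_in_output, detokenize plus logprobs, bad_words, allowed_token_ids). A rough usage sketch, with a placeholder model and prompt, might look like this:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for the sketch
prompt = "Hello, my name is"

# Stop at a literal string but keep it in the returned text.
stop_params = SamplingParams(temperature=0, stop=["."], include_stop_str_in_output=True)

# Forbid specific words from appearing in the completion.
bad_words_params = SamplingParams(temperature=0, bad_words=["Seattle"])

# Skip detokenization but still request logprobs.
raw_params = SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)

for params in (stop_params, bad_words_params, raw_params):
    out = llm.generate(prompt, params)
    print(out[0].outputs[0].token_ids[:5])
```
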
@@ -5,8 +5,10 @@ import torch
from torch import Generator

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available)
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p,
is_flashinfer_available,
)

DEVICE = current_platform.device_type

@@ -30,19 +32,18 @@ def reset_default_device():


def test_topk_impl_equivalence():

torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(33)

logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

# Random top-k values between 1 and 9.
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)

# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(
torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
VOCAB_SIZE)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE
)

# Top-k only implementation
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
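For reference, the behaviour apply_top_k_top_p is checked against can be written naively: keep the k largest logits of each row and mask the rest to -inf (setting k to the vocabulary size disables the filter for that row). A slow but readable sketch, not the vLLM kernel:

```python
import torch

def naive_top_k(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    out = logits.clone()
    for i in range(out.shape[0]):
        # Smallest value that survives the top-k cut for this row.
        threshold = torch.topk(out[i], int(k[i])).values[-1]
        out[i][out[i] < threshold] = float("-inf")
    return out

logits = torch.rand(4, 16)
k = torch.tensor([1, 3, 5, 16])  # k == vocab size leaves that row untouched
masked = naive_top_k(logits, k)
assert int((masked[0] > float("-inf")).sum()) == 1
```
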
@@ -55,7 +56,7 @@ def test_topk_impl_equivalence():


def test_flashinfer_sampler():
'''
"""
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.

@@ -63,11 +64,10 @@ def test_flashinfer_sampler():
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
"""

if not FLASHINFER_ENABLED:
pytest.skip(
"FlashInfer not installed or not available on this platform.")
pytest.skip("FlashInfer not installed or not available on this platform.")

torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(42)
@@ -76,23 +76,21 @@ def test_flashinfer_sampler():
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

# Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand(
(BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
p_values = (
torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
) # range in [0.5, 1.0]

# Sometimes disable top-k (k=vocab_size)
k_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), VOCAB_SIZE)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
VOCAB_SIZE,
)

# Sometimes disable top-p (p=1.0)
p_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), 1.0)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
)

python_logits = apply_top_k_top_p(
logits=logits.clone(),
@@ -113,5 +111,6 @@ def test_flashinfer_sampler():
)

# Compare the results
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
"FlashInfer and Python sampling implementations do not match!"
)

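Because sampling itself is random, the test compares probabilities after top-k and then top-p renormalization rather than sampled tokens. A compact sketch of top-p (nucleus) renormalization on a single distribution, to illustrate what is being compared (not the vLLM or FlashInfer code):

```python
import torch

def top_p_renorm(probs: torch.Tensor, p: float) -> torch.Tensor:
    sorted_probs, idx = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Keep the smallest prefix whose mass reaches p; zero out the rest.
    keep = cumulative - sorted_probs < p
    sorted_probs[~keep] = 0.0
    renormed = sorted_probs / sorted_probs.sum()
    return renormed.gather(-1, idx.argsort())  # back to original token order

probs = torch.softmax(torch.rand(8), dim=-1)
renormed = top_p_renorm(probs, p=0.8)
assert torch.isclose(renormed.sum(), torch.tensor(1.0))
```
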
@@ -16,6 +16,7 @@ from vllm.v1.sample.metadata import SamplingMetadata

class BatchLogprobsComposition(Enum):
"""Types of logprobs configs to include in test batch"""

NONE = 0
SAMPLE = 1
PROMPT = 2
@@ -26,10 +27,10 @@ BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]


def get_test_batch(
batch_logprobs_composition: BatchLogprobsComposition
batch_logprobs_composition: BatchLogprobsComposition,
) -> BatchLogprobsSpecType:
"""Generate logprobs configs for a batch of requests


A given request's logprobs configuration is (1) num_sample_logprobs and (2)
num_prompt_logprobs. The batch logprobs configuration is the list of request
logprobs configs.
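Concretely, the returned BatchLogprobsSpecType is one (num_sample_logprobs, num_prompt_logprobs) tuple per request, with None disabling that kind of logprob for the request. An illustrative value (numbers are made up):

```python
# Hypothetical batch spec: one tuple per request in the batch.
batch_spec = [
    (None, None),  # no sample logprobs, no prompt logprobs
    (0, 0),        # only the chosen token's own logprob
    (5, 3),        # top-5 sample logprobs, top-3 prompt logprobs
    (6, None),     # sample logprobs only
]
```
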
@@ -101,7 +102,7 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
msg: str,
) -> None:
"""Compare incrementally detok. text to non-incrementally detok. text


Fail if the strings mismatch after non-alphanumeric characters are stripped
out.

@@ -120,15 +121,15 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
tokens
msg: error message if `assert` fails
"""
rgx = r'[^a-zA-Z0-9]+'
assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub(
rgx, '', non_incremental_detokenization_str)), (msg)
rgx = r"[^a-zA-Z0-9]+"
assert re.sub(rgx, "", incremental_detokenization_str) == re.sub(
rgx, "", non_incremental_detokenization_str
), msg


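The regex strips every character that is not a letter or digit before comparing, so whitespace and punctuation differences between the two detokenizations are ignored. For example:

```python
import re

rgx = r"[^a-zA-Z0-9]+"
# Both sides reduce to "Helloworld", so the comparison passes.
assert re.sub(rgx, "", "Hello,  world!") == re.sub(rgx, "", "Hello world")
```
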
def compute_correct_cumulative_logprob(
completion_output: CompletionOutput) -> float:
def compute_correct_cumulative_logprob(completion_output: CompletionOutput) -> float:
"""Compute known-good value for evaluating cumulative logprob


Args:
completion_output: completion output from engine

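The known-good value is simply the sum of the logprob assigned to each sampled token. A hedged sketch of that reference computation (plain floats stand in for the engine's per-position logprob objects):

```python
def cumulative_logprob_reference(token_ids, logprobs) -> float:
    # `logprobs` is assumed to be one {token_id: logprob} mapping per position.
    total = 0.0
    for pos, token_id in enumerate(token_ids):
        total += logprobs[pos][token_id]
    return total

assert cumulative_logprob_reference([3, 7], [{3: -0.5}, {7: -1.25}]) == -1.75
```
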
@@ -146,12 +147,12 @@ def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
return fake_logits


def create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def create_penalty_tensor(
batch_size: int, penalty_value: float, device: torch.device
) -> torch.Tensor:
return torch.full(
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
)


def create_prompt_tokens_tensor(
@@ -170,6 +171,7 @@ def create_prompt_tokens_tensor(

class LogitsprocsTestFakes(NamedTuple):
"""Wraps fake data structures to support testing"""

logits: torch.Tensor
sampling_metadata: SamplingMetadata

@@ -178,15 +180,16 @@ class LogitsprocsTestFakes(NamedTuple):
cls: type[LogitsProcessor],
) -> Iterator[LogitsProcessor]:
"""Yield logits processors of a specific class.


Args:
cls: :class:`LogitsProcessor` subclass

Returns:
Iterator over logits processors
"""
return (lp for lp in self.sampling_metadata.logitsprocs.all
if isinstance(lp, cls))
return (
lp for lp in self.sampling_metadata.logitsprocs.all if isinstance(lp, cls)
)

def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
@@ -208,8 +211,7 @@ def fake_apply_logitsprocs(
slice_indices: list[int],
) -> torch.Tensor:
"""Imitate application of logits processors in engine core"""
logits = test_fakes.logits[torch.tensor(slice_indices,
dtype=torch.long)].clone()
logits = test_fakes.logits[torch.tensor(slice_indices, dtype=torch.long)].clone()
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits

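fake_apply_logitsprocs mirrors the engine-core pattern of threading the logits tensor through each processor's apply() in turn. A minimal standalone sketch of that chain (the processor classes here are invented for illustration):

```python
import torch

class AddBias:
    def __init__(self, bias: float) -> None:
        self.bias = bias

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits + self.bias

class Temperature:
    def __init__(self, t: float) -> None:
        self.t = t

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits / self.t

logits = torch.zeros(2, 4)
for processor in (AddBias(1.0), Temperature(2.0)):
    logits = processor.apply(logits)
assert torch.allclose(logits, torch.full((2, 4), 0.5))
```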