Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,9 +9,12 @@ import pytest
import torch
from tests.v1.sample.utils import (
BatchLogprobsComposition, BatchLogprobsSpecType,
BatchLogprobsComposition,
BatchLogprobsSpecType,
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob, get_test_batch)
compute_correct_cumulative_logprob,
get_test_batch,
)
from vllm import SamplingParams
from vllm.config import LogprobsMode
@@ -29,22 +32,23 @@ SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
@pytest.fixture(
scope="module",
# Parameterize APC
params=[False, True])
params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
with vllm_runner(
MODEL,
dtype=DTYPE,
max_logprobs=7,
# Very small number of batched tokens to ensure
# that we test chunking.
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enforce_eager=True,
#TODO: enable this once we support it for
# prompt logprobs.
enable_prefix_caching=request.param,
gpu_memory_utilization=0.4, # up to 2 alive concurrently
MODEL,
dtype=DTYPE,
max_logprobs=7,
# Very small number of batched tokens to ensure
# that we test chunking.
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enforce_eager=True,
# TODO: enable this once we support it for
# prompt logprobs.
enable_prefix_caching=request.param,
gpu_memory_utilization=0.4, # up to 2 alive concurrently
) as vllm_model:
yield vllm_model
@@ -96,8 +100,8 @@ def _repeat_logprob_config(
num_test_prompts = len(test_prompts)
# Make sure there is a logprobs configuration for each test prompt
logprob_prompt_logprob_list = list(
itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
num_test_prompts))
itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
)
# Now the number of prompts should match the number of sample params combos
assert num_test_prompts == len(logprob_prompt_logprob_list)
return logprob_prompt_logprob_list
@@ -115,24 +119,28 @@ def _run_and_validate(
do_apc: bool,
) -> None:
vllm_results = vllm_model.llm.generate(
test_prompts, sampling_params=vllm_sampling_params)
test_prompts, sampling_params=vllm_sampling_params
)
for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
vllm_results, hf_logprobs, hf_outputs,
logprob_prompt_logprob_list):
vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
):
# Extract request-level (prompt)logprobs config
num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
# Test whether sampled token output is consistent between vLLM and HF
# vLLM prompt+completion should match HF output
if temperature == 0.0:
assert (vllm_result.prompt_token_ids +
vllm_result.outputs[0].token_ids == hf_output[0])
assert (
vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
== hf_output[0]
)
else:
# Sampled tokens won't match if not greedy
assert (vllm_result.prompt_token_ids == hf_output[0]
[:len(vllm_result.prompt_token_ids)])
assert (
vllm_result.prompt_token_ids
== hf_output[0][: len(vllm_result.prompt_token_ids)]
)
# Validate sample logprobs
if num_top_logprobs is not None:
@@ -141,8 +149,9 @@ def _run_and_validate(
# correct
assert vllm_result.outputs[0].logprobs is not None
assert len(vllm_result.outputs[0].logprobs) == max_tokens
for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
vllm_result.outputs[0].token_ids):
for logprobs, token_id in zip(
vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
):
assert logprobs is not None
# Confirm that the output token appears among the logprobs
@@ -159,23 +168,26 @@ def _run_and_validate(
if num_top_logprobs > 0:
# We should have an entry for each of the topk ranks
all_ranks = {lp.rank for lp in logprobs.values()}
assert all(r in all_ranks
for r in range(1, num_top_logprobs + 1))
assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))
output_text = vllm_result.outputs[0].text
output_string_from_most_likely_tokens_lst: list[str] = []
for top_logprobs in vllm_result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token)
top_logprob.decoded_token
)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens_lst)
output_string_from_most_likely_tokens_lst
)
assert_incr_detok_str_matches_non_incr_detok_str(
output_text, output_string_from_most_likely_tokens,
output_text,
output_string_from_most_likely_tokens,
"The output text from the top logprob for each token "
"position should be the same as the output text in the "
"result.")
"result.",
)
# Compare vLLM sample logprobs to HF
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
@@ -187,11 +199,12 @@ def _run_and_validate(
logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(
sample_logprob.decoded_token,
str), ("The token should be decoded by the time it is"
" returned to the user.")
rtol=1e-2,
)
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is"
" returned to the user."
)
# At this point we know the sample logprobs are correct for this
# request. Validate that cumulative_logprob is actually the sum.
@@ -201,7 +214,8 @@ def _run_and_validate(
vllm_result.outputs[0].cumulative_logprob,
compute_correct_cumulative_logprob(vllm_result.outputs[0]),
atol=1e-6,
rtol=1e-6)
rtol=1e-6,
)
else:
# Logprobs disabled for this request; should be None
assert vllm_result.outputs[0].logprobs is None
@@ -214,17 +228,17 @@ def _run_and_validate(
assert vllm_result.prompt_logprobs[0] is None
# - Prompt logprobs are returned for all indices in
# the prompt
assert len(vllm_result.prompt_logprobs) == len(
vllm_result.prompt_token_ids)
assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
for prompt_logprobs, prompt_token_id in zip(
vllm_result.prompt_logprobs[1:],
vllm_result.prompt_token_ids[1:]):
vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
):
assert prompt_logprobs is not None
# Confirm that the prompt token appears among the logprobs
assert prompt_token_id in prompt_logprobs
token_in_topk = prompt_logprobs[
prompt_token_id].rank <= num_top_prompt_logprobs
token_in_topk = (
prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
)
# If the prompt token is not included in the top K
# logprob, it can return 1 more data
@@ -236,8 +250,9 @@ def _run_and_validate(
if num_top_prompt_logprobs > 0:
# We should have an entry for each of the topk ranks
all_ranks = {lp.rank for lp in prompt_logprobs.values()}
assert all(r in all_ranks
for r in range(1, num_top_prompt_logprobs + 1))
assert all(
r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
)
# Compare prompt logprobs to HF
# The first prompt logprob is always None, so we compare it from
@@ -249,19 +264,24 @@ def _run_and_validate(
logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=2e-2,
rtol=2e-2)
rtol=2e-2,
)
else:
assert vllm_result.prompt_logprobs is None
@pytest.mark.parametrize("batch_logprobs_composition",
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize(
"batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
hf_model, vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float, example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch) -> None:
hf_model,
vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float,
example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
@@ -291,8 +311,9 @@ def test_get_logprobs_and_prompt_logprobs(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT):
if do_apc and (
temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT
):
# Skip some test-cases to save time.
pytest.skip()
test_prompts = example_prompts
@@ -309,19 +330,21 @@ def test_get_logprobs_and_prompt_logprobs(
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
test_prompts, logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
SamplingParams(
max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984,
)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
@@ -334,7 +357,8 @@ def test_get_logprobs_and_prompt_logprobs(
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
max_tokens=max_tokens,
do_apc=do_apc)
do_apc=do_apc,
)
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
@@ -351,19 +375,18 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
max_model_len=256,
)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
@@ -388,14 +411,12 @@ def test_none_logprobs(vllm_model, example_prompts,
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs
Args:
@@ -406,12 +427,12 @@ def test_zero_logprobs(vllm_model, example_prompts,
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
logprobs=0,
prompt_logprobs=0,
temperature=0.0)
sampling_params_logprobs_zero = SamplingParams(
max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
)
results_logprobs_zero = vllm_model.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_zero)
example_prompts, sampling_params=sampling_params_logprobs_zero
)
for i in range(len(results_logprobs_zero)):
# Check that there is one sample logprob dict for each
@@ -422,8 +443,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
assert logprobs is not None
assert len(sampled_token_ids) == len(logprobs)
assert results_logprobs_zero[i].outputs[
0].cumulative_logprob is not None
assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
# Check that there is one prompt logprob dict for each
# prompt token
assert prompt_logprobs is not None
@@ -444,13 +464,15 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
max_model_len=256,
)
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
logprobs=-1,
prompt_logprobs=-1)
sampling_params_logprobs_all = SamplingParams(
max_tokens=5, logprobs=-1, prompt_logprobs=-1
)
results_logprobs_all = runner.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_all)
example_prompts, sampling_params=sampling_params_logprobs_all
)
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
for i in range(len(results_logprobs_all)):
@@ -466,13 +488,13 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch):
def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode.
For logprobs, we should have non-positive values.
For logits, we should expect at least one positive values.
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -483,10 +505,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
# 2 other llms alive during whole session
gpu_memory_utilization=0.05,
max_model_len=16,
logprobs_mode=logprobs_mode)
logprobs_mode=logprobs_mode,
)
vllm_sampling_params = SamplingParams(logprobs=1)
results = llm.generate(["Hello world"],
sampling_params=vllm_sampling_params)
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
total_token_with_logprobs = 0
positive_values = 0

View File

@@ -15,22 +15,23 @@ EXPECTED_VALUE = 0.62
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching",
"--gpu-memory-utilization=0.8"
"--enforce_eager",
"--no_enable_prefix_caching",
"--gpu-memory-utilization=0.8",
]
NUM_CONCURRENT = 100
def test_prompt_logprobs_e2e():
results = lm_eval.simple_evaluate(model="vllm",
model_args=MODEL_ARGS,
tasks=TASK,
batch_size="auto")
results = lm_eval.simple_evaluate(
model="vllm", model_args=MODEL_ARGS, tasks=TASK, batch_size="auto"
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
assert (
measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
def test_prompt_logprobs_e2e_server():
@@ -40,7 +41,8 @@ def test_prompt_logprobs_e2e_server():
model_args = (
f"model={MODEL},"
f"base_url={url},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
)
results = lm_eval.simple_evaluate(
model="local-completions",
@@ -49,6 +51,7 @@ def test_prompt_logprobs_e2e_server():
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
assert (
measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

View File

@@ -9,8 +9,7 @@ import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type
@@ -21,10 +20,11 @@ def rejection_sampler():
return RejectionSampler()
def create_logits_tensor(output_token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor:
def create_logits_tensor(
output_token_ids: list[list[int]], vocab_size: int = 100
) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
@@ -44,8 +44,8 @@ def create_sampling_metadata(
generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
is used.
to the given value. Either all greedy or all random sampling
is used.
"""
generators = generators or {}
if all_greedy:
@@ -81,10 +81,10 @@ def test_perfect_match(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -93,9 +93,7 @@ def test_perfect_match(rejection_sampler):
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
@@ -106,10 +104,10 @@ def test_early_mismatch(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -129,15 +127,16 @@ def test_early_mismatch(rejection_sampler):
def test_multiple_sequences(rejection_sampler):
"""Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]]
output_tokens = [[1, 2, 5], [3,
4]] # Two sequences with bonus tokens 5 and 4
output_tokens = [[1, 2, 5], [3, 4]] # Two sequences with bonus tokens 5 and 4
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -146,9 +145,9 @@ def test_multiple_sequences(rejection_sampler):
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
expected = torch.tensor(
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
)
assert torch.equal(output, expected)
@@ -159,10 +158,10 @@ def test_single_token_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -182,10 +181,10 @@ def test_empty_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -201,15 +200,16 @@ def test_empty_sequence(rejection_sampler):
def test_multiple_mismatches(rejection_sampler):
"""Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
9]] # Mismatches in both sequences
output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]] # Mismatches in both sequences
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -219,8 +219,10 @@ def test_multiple_mismatches(rejection_sampler):
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
[
[1, 2, 7, PLACEHOLDER_TOKEN_ID],
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
],
dtype=torch.int,
device=logits.device,
)
@@ -232,18 +234,23 @@ def test_multiple_mismatches(rejection_sampler):
[
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
expected):
(
[[1, 2], [3, 4]],
[[1, 5, 6], [3, 4, 7]],
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
), # Mixed matches
],
)
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected):
"""Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
bonus_token_tensor = torch.tensor(
[tokens[-1] for tokens in output_tokens], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
@@ -252,9 +259,7 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
assert torch.equal(output, expected_tensor)
@@ -273,22 +278,15 @@ def test_deterministic_when_seeded(
n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE)
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=DEVICE)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device=DEVICE)
bonus_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
)
draft_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
@@ -296,17 +294,17 @@ def test_deterministic_when_seeded(
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
for i in range(batch_size)
if seeded_mask[i]
}
temperature = torch.ones(batch_size,
dtype=torch.float32,
device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature,
generators=seeded_seqs)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=DEVICE)
draft_token_ids.tolist(), device=DEVICE
)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
@@ -352,8 +350,7 @@ def test_rejection_sampling_approximates_target_distribution():
num_reference_probs = 100
# Prepare draft, target, and reference probability distributions
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
dim=-1)
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1)
target_logits = torch.rand(vocab_size, dtype=torch.float32)
target_probs = F.softmax(target_logits, dim=-1)
reference_probs = F.softmax(
@@ -368,38 +365,48 @@ def test_rejection_sampling_approximates_target_distribution():
for num_samples in sample_sizes:
# Sample using rejection sampling.
rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_logits, k, vocab_size, num_samples)
draft_probs, target_logits, k, vocab_size, num_samples
)
rej_sample_probs = rej_sample_probs.to(DEVICE)
# Average distance from reference probs.
reference_vs_rejsample_dist = torch.dist(
reference_probs,
rej_sample_probs).item() / reference_probs.shape[0]
target_vs_rejsample_dist = torch.dist(target_probs,
rej_sample_probs).item()
reference_vs_rejsample_dist = (
torch.dist(reference_probs, rej_sample_probs).item()
/ reference_probs.shape[0]
)
target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item()
distance_wrt_reference.append(reference_vs_rejsample_dist)
distance_wrt_target.append(target_vs_rejsample_dist)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
distance_wrt_target
)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
distance_wrt_reference
)
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}")
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}")
print(
f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}"
)
print(
f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}"
)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
distance_wrt_target
)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
distance_wrt_reference
)
expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target
> relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)
assert (
relative_change_in_distance_wrt_target
> relative_change_in_distance_wrt_reference * expected_improvement_multiplier
)
def get_ratio_first_to_last(elements: list[float]) -> float:
@@ -427,28 +434,29 @@ def estimate_rejection_sampling_pdf(
rejection_sampler = RejectionSampler()
num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1,
vocab_size).repeat(num_samples, k, 1)
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
# Repeat target probs num_tokens times.
target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=k,
replacement=True).reshape(
num_samples, k)
draft_token_ids = torch.multinomial(
draft_probs[:, 0, :], num_samples=k, replacement=True
).reshape(num_samples, k)
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
device=DEVICE).repeat(num_samples, 1)
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
num_samples, 1
)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device)
draft_token_ids.tolist(), device=bonus_token_ids.device
)
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
@@ -458,11 +466,12 @@ def estimate_rejection_sampling_pdf(
)
output_token_ids = output_token_ids[:, :-1].flatten()
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
device="cpu"),
bins=vocab_size,
range=(0, vocab_size),
density=True)
hist = torch.histogram(
output_token_ids.to(dtype=torch.float, device="cpu"),
bins=vocab_size,
range=(0, vocab_size),
density=True,
)
return hist.hist
@@ -480,9 +489,9 @@ def _test_masked_logits(
num_tokens = batch_size * num_draft_tokens
# Create random draft probabilities.
draft_probs = torch.rand((num_tokens, vocab_size),
dtype=torch.float32,
device=DEVICE)
draft_probs = torch.rand(
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
)
draft_probs = F.softmax(draft_probs, dim=-1)
# Randomly sample draft token ids from draft probs
@@ -491,9 +500,7 @@ def _test_masked_logits(
draft_token_ids = draft_token_ids.tolist()
# Bonus tokens not used but required
bonus_token_ids = torch.zeros((batch_size, 1),
dtype=torch.int64,
device=DEVICE)
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
# Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
@@ -531,8 +538,7 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices.
top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k]
for _ in range(num_tokens)
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)
]
top_k_indices = torch.stack(top_k_indices)
@@ -550,9 +556,7 @@ def test_top_k(rejection_sampler, top_k):
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_k=torch.tensor([top_k] * batch_size,
device=DEVICE,
dtype=torch.int64),
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
)
_test_masked_logits(
@@ -595,9 +599,7 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_p=torch.tensor([top_p] * batch_size,
device=DEVICE,
dtype=torch.float32),
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
)
_test_masked_logits(

View File

@@ -29,12 +29,12 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
return fake_logits
def _create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def _create_penalty_tensor(
batch_size: int, penalty_value: float, device: torch.device
) -> torch.Tensor:
return torch.full(
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
)
def _create_prompt_tokens_tensor(
@@ -62,9 +62,9 @@ def _create_allowed_token_ids(
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros((batch_size, vocab_size),
dtype=torch.bool,
device=device)
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
@@ -80,9 +80,9 @@ def _create_bad_words_token_ids(
for batch_idx in range(batch_size):
token_ids_single_batch = []
for bad_words_length in bad_words_lengths:
token_ids = np.random.choice(vocab_size,
size=bad_words_length,
replace=True).tolist()
token_ids = np.random.choice(
vocab_size, size=bad_words_length, replace=True
).tolist()
token_ids_single_batch.append(token_ids)
bad_words_token_ids[batch_idx] = token_ids_single_batch
if batch_size >= 2:
@@ -95,26 +95,27 @@ def _create_bad_words_token_ids(
# Returns all last tokens of bad word sequences that share the same prefix
# as `given_prefix` (excluding the last token).
def _collect_suffixes_with_same_prefix(
given_prefix: list[int],
bad_words_token_ids: list[list[int]]) -> list[int]:
given_prefix: list[int], bad_words_token_ids: list[list[int]]
) -> list[int]:
return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix]
# generate a valid token id that is not in bad_words_token_ids
def _generate_valid_token_id(bad_words_token_ids: list[list[int]],
vocab_size: int) -> int:
def _generate_valid_token_id(
bad_words_token_ids: list[list[int]], vocab_size: int
) -> int:
forbidden_start_tokens = set()
for bad_word in bad_words_token_ids:
forbidden_start_tokens.add(bad_word[0])
# Get a safe token that's not in forbidden starts
safe_token_candidates = list(
set(range(vocab_size)) - forbidden_start_tokens)
safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens)
# Pick a random safe token
return np.random.choice(safe_token_candidates)
def _update_output_token_ids_for_bad_words(
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
metadata: SamplingMetadata, vocab_size: int
) -> dict[int, list[int]]:
bad_words_last_tokens = {}
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
output_token_ids = metadata.output_token_ids[batch_idx]
@@ -132,12 +133,13 @@ def _update_output_token_ids_for_bad_words(
# Collect all last tokens from other bad words
# that share this prefix
bad_words_last_token.extend(
_collect_suffixes_with_same_prefix(
prefix, bad_words_token_ids))
_collect_suffixes_with_same_prefix(prefix, bad_words_token_ids)
)
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
output_token_ids[-1] = _generate_valid_token_id(
bad_words_token_ids, vocab_size)
bad_words_token_ids, vocab_size
)
bad_words_last_tokens[batch_idx] = bad_words_last_token
return bad_words_last_tokens
@@ -152,22 +154,24 @@ def _create_default_sampling_metadata(
prompt_token_ids: list[list[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
)
prompt_token_ids.append(
np.random.randint(0,
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
np.random.randint(
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
).tolist()
)
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
temperature=torch.full((batch_size,), 0.0),
all_greedy=True,
all_random=False,
top_p=None,
top_k=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
prompt_token_ids=_create_prompt_tokens_tensor(
prompt_token_ids, vocab_size, device
),
output_token_ids=output_token_ids,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
@@ -181,8 +185,8 @@ def _create_default_sampling_metadata(
def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
batch_size: int, vocab_size: int
) -> tuple[list[list[int]], list[list[int]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.
@@ -203,14 +207,13 @@ def _create_weighted_output_token_list(
output_token_ids: list[list[int]] = []
sorted_token_ids_in_output: list[list[int]] = []
for _ in range(batch_size):
distinct_token_ids = np.random.choice(vocab_size,
size=np.random.randint(1, 10),
replace=False).tolist()
distinct_token_ids = np.random.choice(
vocab_size, size=np.random.randint(1, 10), replace=False
).tolist()
sorted_token_ids_in_output.append(distinct_token_ids)
output_token_ids_for_batch = []
for index, token_id in enumerate(distinct_token_ids):
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids_for_batch.extend([token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return output_token_ids, sorted_token_ids_in_output
@@ -218,8 +221,9 @@ def _create_weighted_output_token_list(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(device: str, batch_size: int,
presence_penalty: float):
def test_sampler_presence_penalty(
device: str, batch_size: int, presence_penalty: float
):
"""
Test to verify that if presence penalty is enabled then tokens
are penalized as per their presence in the existing output.
@@ -229,10 +233,12 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
output_token_ids = sampling_metadata.output_token_ids
sampling_metadata.presence_penalties = _create_penalty_tensor(
batch_size, presence_penalty, torch.device(device))
batch_size, presence_penalty, torch.device(device)
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
@@ -263,8 +269,9 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(device: str, batch_size: int,
frequency_penalty: float):
def test_sampler_frequency_penalty(
device: str, batch_size: int, frequency_penalty: float
):
"""
Test to verify that if frequency penalty is enabled then tokens are
penalized as per their frequency of occurrence.
@@ -274,14 +281,15 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
sampling_metadata.frequency_penalties = _create_penalty_tensor(
batch_size, frequency_penalty, torch.device(device))
output_token_ids, sorted_token_ids_in_output = \
_create_weighted_output_token_list(
batch_size,
VOCAB_SIZE,
)
batch_size, frequency_penalty, torch.device(device)
)
output_token_ids, sorted_token_ids_in_output = _create_weighted_output_token_list(
batch_size,
VOCAB_SIZE,
)
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
@@ -290,18 +298,17 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
batch_idx]
distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[batch_idx]
most_frequent_token_id = distinct_sorted_token_ids_in_output[
len(distinct_sorted_token_ids_in_output) - 1]
len(distinct_sorted_token_ids_in_output) - 1
]
if frequency_penalty > 0:
# If `frequency_penalty` is set to > 0, it indicates
# a preference for new tokens over existing ones. Verify that the
# non-penalized token ID is not present in the output, while the
# most penalized token is the one that occurs most frequently in
# the output.
assert (non_penalized_token_id
not in distinct_sorted_token_ids_in_output)
assert non_penalized_token_id not in distinct_sorted_token_ids_in_output
assert penalized_token_id == most_frequent_token_id
elif frequency_penalty < 0:
# If `frequency_penalty` is set to < 0, it indicates
@@ -316,8 +323,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float):
def test_sampler_repetition_penalty(
device: str, batch_size: int, repetition_penalty: float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
@@ -328,9 +336,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
sampling_metadata.repetition_penalties = _create_penalty_tensor(
batch_size, repetition_penalty, torch.device(device))
batch_size, repetition_penalty, torch.device(device)
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
@@ -338,32 +348,40 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
prompt_tokens = sampling_metadata.prompt_token_ids[
batch_idx][:].tolist()
prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx][:].tolist()
output_tokens = sampling_metadata.output_token_ids[batch_idx]
if repetition_penalty > 1.0:
# If `repetition_penalty` > 1.0, verify that the non-penalized
# token ID has not been seen before, while the penalized token ID
# exists either in the prompt or the output.
assert (non_penalized_token_id not in prompt_tokens
and non_penalized_token_id not in output_tokens)
assert (penalized_token_id in prompt_tokens
or penalized_token_id in output_tokens)
assert (
non_penalized_token_id not in prompt_tokens
and non_penalized_token_id not in output_tokens
)
assert (
penalized_token_id in prompt_tokens
or penalized_token_id in output_tokens
)
elif repetition_penalty < 1.0:
# If `repetition_penalty` < 1.0, verify that the penalized
# token ID has not been seen before, while the non-penalized
# token ID exists either in the prompt or the output.
assert (penalized_token_id not in prompt_tokens
and penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens
or non_penalized_token_id in output_tokens)
assert (
penalized_token_id not in prompt_tokens
and penalized_token_id not in output_tokens
)
assert (
non_penalized_token_id in prompt_tokens
or non_penalized_token_id in output_tokens
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
num_allowed_token_ids: int):
def test_sampler_allowed_token_ids(
device: str, batch_size: int, num_allowed_token_ids: int
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
@@ -374,7 +392,8 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
mask = _create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
@@ -394,17 +413,19 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
start = min(batch_idx, VOCAB_SIZE - 1)
end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
if token_id >= start and token_id < end:
assert logits_for_req[token_id] == -float(
"inf"), f"{batch_idx}, {token_id}"
assert logits_for_req[token_id] == -float("inf"), (
f"{batch_idx}, {token_id}"
)
else:
assert logits_for_req[token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
bad_words_lengths: tuple[int, ...]):
@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)])
def test_sampler_bad_words(
device: str, batch_size: int, bad_words_lengths: tuple[int, ...]
):
"""
Test to verify that when the bad words restriction is present, tokens
are penalized based on their match with the bad words.
@@ -414,19 +435,24 @@ def test_sampler_bad_words(device: str, batch_size: int,
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
batch_size, VOCAB_SIZE, bad_words_lengths)
batch_size, VOCAB_SIZE, bad_words_lengths
)
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
sampling_metadata, VOCAB_SIZE)
sampling_metadata, VOCAB_SIZE
)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
for token_id in range(VOCAB_SIZE):
if (batch_idx in bad_words_last_tokens
and token_id in bad_words_last_tokens[batch_idx]):
if (
batch_idx in bad_words_last_tokens
and token_id in bad_words_last_tokens[batch_idx]
):
assert logits_for_req[token_id] == -float("inf")
else:
assert logits_for_req[token_id] != -float("inf")

View File

@@ -66,9 +66,9 @@ def test_stop(llm):
# Output should not contain the stop word.
assert len(new_split_text) == STOP_IDX
params = SamplingParams(temperature=0,
stop=split_text[STOP_IDX],
include_stop_str_in_output=True)
params = SamplingParams(
temperature=0, stop=split_text[STOP_IDX], include_stop_str_in_output=True
)
output = llm.generate(PROMPT, params)
new_split_text = output[0].outputs[0].text.split()
@@ -103,8 +103,8 @@ def test_detokenize_false(llm):
assert len(output[0].outputs[0].text) == 0
output = llm.generate(
PROMPT, SamplingParams(detokenize=False, logprobs=3,
prompt_logprobs=3))
PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)
)
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0
@@ -131,8 +131,7 @@ def test_bad_words(llm):
assert bad_words_1 not in new_text
bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0,
bad_words=[bad_words_1, bad_words_2])
params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2])
output = llm.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
@@ -158,8 +157,7 @@ def test_allowed_token_ids(llm):
TOKEN_ID = 10
allowed_token_ids = [TOKEN_ID]
output = llm.generate(PROMPT,
SamplingParams(allowed_token_ids=allowed_token_ids))
output = llm.generate(PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
assert output[0].outputs[0].token_ids[-1] == TOKEN_ID
# Reject empty allowed_token_ids.

View File

@@ -5,8 +5,10 @@ import torch
from torch import Generator
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available)
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p,
is_flashinfer_available,
)
DEVICE = current_platform.device_type
@@ -30,19 +32,18 @@ def reset_default_device():
def test_topk_impl_equivalence():
torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(33)
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Random top-k values between 1 and 9.
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k.masked_fill_(
torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
VOCAB_SIZE)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE
)
# Top-k only implementation
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
@@ -55,7 +56,7 @@ def test_topk_impl_equivalence():
def test_flashinfer_sampler():
'''
"""
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.
@@ -63,11 +64,10 @@ def test_flashinfer_sampler():
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
"""
if not FLASHINFER_ENABLED:
pytest.skip(
"FlashInfer not installed or not available on this platform.")
pytest.skip("FlashInfer not installed or not available on this platform.")
torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(42)
@@ -76,23 +76,21 @@ def test_flashinfer_sampler():
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand(
(BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
p_values = (
torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
) # range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
k_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), VOCAB_SIZE)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
VOCAB_SIZE,
)
# Sometimes disable top-p (p=1.0)
p_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), 1.0)
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
)
python_logits = apply_top_k_top_p(
logits=logits.clone(),
@@ -113,5 +111,6 @@ def test_flashinfer_sampler():
)
# Compare the results
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
"FlashInfer and Python sampling implementations do not match!"
)

View File

@@ -16,6 +16,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
class BatchLogprobsComposition(Enum):
"""Types of logprobs configs to include in test batch"""
NONE = 0
SAMPLE = 1
PROMPT = 2
@@ -26,10 +27,10 @@ BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]
def get_test_batch(
batch_logprobs_composition: BatchLogprobsComposition
batch_logprobs_composition: BatchLogprobsComposition,
) -> BatchLogprobsSpecType:
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
num_prompt_logprobs. The batch logprobs configuration is the list of request
logprobs configs.
@@ -101,7 +102,7 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
msg: str,
) -> None:
"""Compare incrementally detok. text to non-incrementally detok. text
Fail if the strings mismatch after non-alphanumeric characters are stripped
out.
@@ -120,15 +121,15 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
tokens
msg: error message if `assert` fails
"""
rgx = r'[^a-zA-Z0-9]+'
assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub(
rgx, '', non_incremental_detokenization_str)), (msg)
rgx = r"[^a-zA-Z0-9]+"
assert re.sub(rgx, "", incremental_detokenization_str) == re.sub(
rgx, "", non_incremental_detokenization_str
), msg
def compute_correct_cumulative_logprob(
completion_output: CompletionOutput) -> float:
def compute_correct_cumulative_logprob(completion_output: CompletionOutput) -> float:
"""Compute known-good value for evaluating cumulative logprob
Args:
completion_output: completion output from engine
@@ -146,12 +147,12 @@ def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
return fake_logits
def create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def create_penalty_tensor(
batch_size: int, penalty_value: float, device: torch.device
) -> torch.Tensor:
return torch.full(
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
)
def create_prompt_tokens_tensor(
@@ -170,6 +171,7 @@ def create_prompt_tokens_tensor(
class LogitsprocsTestFakes(NamedTuple):
"""Wraps fake data structures to support testing"""
logits: torch.Tensor
sampling_metadata: SamplingMetadata
@@ -178,15 +180,16 @@ class LogitsprocsTestFakes(NamedTuple):
cls: type[LogitsProcessor],
) -> Iterator[LogitsProcessor]:
"""Yield logits processors of a specific class.
Args:
cls: :class:`LogitsProcessor` subclass
Returns:
Iterator over logits processors
"""
return (lp for lp in self.sampling_metadata.logitsprocs.all
if isinstance(lp, cls))
return (
lp for lp in self.sampling_metadata.logitsprocs.all if isinstance(lp, cls)
)
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
@@ -208,8 +211,7 @@ def fake_apply_logitsprocs(
slice_indices: list[int],
) -> torch.Tensor:
"""Imitate application of logits processors in engine core"""
logits = test_fakes.logits[torch.tensor(slice_indices,
dtype=torch.long)].clone()
logits = test_fakes.logits[torch.tensor(slice_indices, dtype=torch.long)].clone()
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits