[V1][spec decode] return logprobs for spec decoding (#26060)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Giancarlo Delfin
2025-10-22 22:59:59 -07:00
committed by GitHub
parent ff93cc8c84
commit 6644796bf4
8 changed files with 392 additions and 186 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from unittest.mock import Mock
import pytest
import torch
@@ -11,6 +12,7 @@ from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type
@@ -18,7 +20,28 @@ DEVICE = current_platform.device_type
@pytest.fixture
def rejection_sampler():
return RejectionSampler()
mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
return RejectionSampler(mock_sampler)
def mock_sampler_output(
    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
    """Stub the rejection sampler's inner (mocked) sampler.

    Makes the wrapped ``sampler`` Mock return a ``SamplerOutput`` whose
    sampled tokens are exactly *bonus_token_ids*, with no logprobs attached,
    so tests can control the bonus tokens produced during rejection sampling.
    """
    stubbed_output = SamplerOutput(
        sampled_token_ids=bonus_token_ids, logprobs_tensors=None
    )
    rejection_sampler.sampler.return_value = stubbed_output
def create_spec_decode_metadata(
    spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
    """Build dummy ``SpecDecodeMetadata`` consistent with *logits*.

    Every row of *logits* is treated as a target logit; the bonus logit
    index list is left empty because the bonus token ids are produced by
    the mocked sampler rather than selected from these logits.
    """
    num_target_logits = logits.shape[0]
    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
    metadata.target_logits_indices = torch.arange(num_target_logits)
    # Output bonus token ids are mocked, so the bonus logit indices should
    # be empty.
    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
    return metadata
def create_logits_tensor(
@@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_early_mismatch(rejection_sampler):
@@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_sequences(rejection_sampler):
@@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_single_token_sequence(rejection_sampler):
@@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_empty_sequence(rejection_sampler):
@@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_mismatches(rejection_sampler):
@@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize(
@@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
bonus_token_tensor = torch.tensor(
[tokens[-1] for tokens in output_tokens], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
assert torch.equal(output, expected_tensor)
assert torch.equal(output.sampled_token_ids, expected_tensor)
########################### Tests for Random Sampling ###################
@@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=DEVICE
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), target_logits
)
mock_sampler_output(rejection_sampler, bonus_token_ids)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
draft_probs=None,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
results.append(rep_result)
results.append(rep_result.sampled_token_ids)
for i in range(batch_size):
if seeded_mask[i]:
@@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
Returns:
Estimated probability distribution of the output tokens.
"""
rejection_sampler = RejectionSampler()
mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
rejection_sampler = RejectionSampler(mock_sampler)
num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
@@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), target_logits
)
output_token_ids = rejection_sampler(
mock_sampler_output(rejection_sampler, bonus_token_ids)
sampler_output = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
output_token_ids = output_token_ids[:, :-1].flatten()
output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
hist = torch.histogram(
output_token_ids.to(dtype=torch.float, device="cpu"),
@@ -532,22 +545,19 @@ def _test_masked_logits(
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
# Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids,
device=DEVICE,
)
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
# Run rejection sampling
output_token_ids = rejection_sampler(
mock_sampler_output(rejection_sampler, bonus_token_ids)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
# Remove bonus tokens and reshape
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
# Check that all sampled tokens are within the unmasked indices.
for i in range(num_tokens):
@@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_bad_words(rejection_sampler):
@@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
@@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_allowed_token_ids(rejection_sampler):
@@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
@@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)