[V1][spec decode] return logprobs for spec decoding (#26060)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -11,6 +12,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler, SamplerOutput
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
@@ -18,7 +20,28 @@ DEVICE = current_platform.device_type
|
||||
|
||||
@pytest.fixture
|
||||
def rejection_sampler():
|
||||
return RejectionSampler()
|
||||
mock_sampler = Mock(spec=Sampler)
|
||||
mock_sampler.logprobs_mode = "raw_logprobs"
|
||||
return RejectionSampler(mock_sampler)
|
||||
|
||||
|
||||
def mock_sampler_output(
|
||||
rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
|
||||
):
|
||||
rejection_sampler.sampler.return_value = SamplerOutput(
|
||||
sampled_token_ids=bonus_token_ids, logprobs_tensors=None
|
||||
)
|
||||
|
||||
|
||||
def create_spec_decode_metadata(
|
||||
spec_tokens: list[list[int]], logits: torch.Tensor
|
||||
) -> SpecDecodeMetadata:
|
||||
metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
|
||||
metadata.target_logits_indices = torch.arange(logits.shape[0])
|
||||
# Output bonus token ids are mocked, so the bonus logit indices should
|
||||
# be empty.
|
||||
metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
|
||||
return metadata
|
||||
|
||||
|
||||
def create_logits_tensor(
|
||||
@@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_early_mismatch(rejection_sampler):
|
||||
@@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
@@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_multiple_sequences(rejection_sampler):
|
||||
@@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_single_token_sequence(rejection_sampler):
|
||||
@@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_empty_sequence(rejection_sampler):
|
||||
@@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_multiple_mismatches(rejection_sampler):
|
||||
@@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
@@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[tokens[-1] for tokens in output_tokens], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected_tensor)
|
||||
assert torch.equal(output.sampled_token_ids, expected_tensor)
|
||||
|
||||
|
||||
########################### Tests for Random Sampling ###################
|
||||
@@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature, generators=seeded_seqs
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=DEVICE
|
||||
spec_decode_metadata = create_spec_decode_metadata(
|
||||
draft_token_ids.tolist(), target_logits
|
||||
)
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_ids)
|
||||
rep_result = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
draft_probs=None,
|
||||
logits=target_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
results.append(rep_result)
|
||||
results.append(rep_result.sampled_token_ids)
|
||||
|
||||
for i in range(batch_size):
|
||||
if seeded_mask[i]:
|
||||
@@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
|
||||
Returns:
|
||||
Estimated probability distribution of the output tokens.
|
||||
"""
|
||||
rejection_sampler = RejectionSampler()
|
||||
mock_sampler = Mock(spec=Sampler)
|
||||
mock_sampler.logprobs_mode = "raw_logprobs"
|
||||
rejection_sampler = RejectionSampler(mock_sampler)
|
||||
num_tokens = num_samples * k
|
||||
# Repeat draft probs num_samples * k times.
|
||||
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
|
||||
@@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=bonus_token_ids.device
|
||||
spec_decode_metadata = create_spec_decode_metadata(
|
||||
draft_token_ids.tolist(), target_logits
|
||||
)
|
||||
output_token_ids = rejection_sampler(
|
||||
|
||||
mock_sampler_output(rejection_sampler, bonus_token_ids)
|
||||
sampler_output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
logits=target_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
output_token_ids = output_token_ids[:, :-1].flatten()
|
||||
output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
|
||||
|
||||
hist = torch.histogram(
|
||||
output_token_ids.to(dtype=torch.float, device="cpu"),
|
||||
@@ -532,22 +545,19 @@ def _test_masked_logits(
|
||||
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
|
||||
|
||||
# Create spec decode metadata
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids,
|
||||
device=DEVICE,
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
|
||||
|
||||
# Run rejection sampling
|
||||
output_token_ids = rejection_sampler(
|
||||
mock_sampler_output(rejection_sampler, bonus_token_ids)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
logits=target_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# Remove bonus tokens and reshape
|
||||
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
|
||||
output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
|
||||
|
||||
# Check that all sampled tokens are within the unmasked indices.
|
||||
for i in range(num_tokens):
|
||||
@@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
@@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_bad_words(rejection_sampler):
|
||||
@@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_allowed_token_ids(rejection_sampler):
|
||||
@@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
|
||||
)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
spec_tokens, device=logits.device
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
|
||||
mock_sampler_output(rejection_sampler, bonus_token_tensor)
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
logits=logits,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
Reference in New Issue
Block a user