[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
@@ -1,6 +1,7 @@
 import random
+from collections import defaultdict
 from types import SimpleNamespace
-from typing import Dict, List
+from typing import Dict, List, Set
 from unittest.mock import MagicMock
 
 import pytest
@@ -377,8 +378,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -554,7 +557,6 @@ def test_init_device(acceptance_sampler_method: str):
 
     worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
                               metrics_collector)
-
     worker.init_device()
 
     draft_worker.init_device.assert_called_once()
@@ -645,3 +647,140 @@ def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
     assert (num_blocks * target_cache_block_size_bytes) + (
         num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                               target_cache_block_size_bytes)
+
+
+@torch.inference_mode()
+def test_populate_seq_ids_with_bonus_tokens():
+    """
+    Verify that a call to _create_output_sampler_list correctly updates
+    seq_with_bonus_token_in_last_step.
+
+    seq_with_bonus_token_in_last_step is an internal data structure in
+    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
+    tokens by the target model in their last forward pass. This state is
+    maintained only for models relying on the KV cache, such as those using
+    the MultiStepWorker.
+    """
+    batch_size = 10
+    k = 5
+    vocab_size = 10000
+    num_sequences_with_bonus_tokens = 5
+    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+    target_worker.device = 'cuda'
+
+    set_random_seed(1)
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    draft_worker.device = 'cuda'
+    # The sequence_ids attached to each sequence in the batch.
+    # The sequence at index i has seq_id assigned_seq_ids[i].
+    assigned_seq_ids = list(range(batch_size))
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 seq_ids=assigned_seq_ids,
+                                                 prev_output_token_len=10)
+    target_token_logprobs = torch.rand(batch_size, (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
+    accepted_token_ids = torch.randint(low=0,
+                                       high=vocab_size,
+                                       size=(batch_size, (k + 1)),
+                                       dtype=torch.int64,
+                                       device='cuda')
+    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for seq_group_metadata in seq_group_metadata_list:
+        for seq_id in seq_group_metadata.seq_data:
+            expected_request_id_seq_ids_mapping[
+                seq_group_metadata.request_id].add(seq_id)
+    # Generate a random sample of sequence indexes with bonus tokens.
+    seq_indexes_with_bonus_tokens = random.sample(
+        range(batch_size), num_sequences_with_bonus_tokens)
+    # Create a mask that is False at the indices in
+    # seq_indexes_with_bonus_tokens and True everywhere else.
+    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
+    mask[seq_indexes_with_bonus_tokens] = False
+    # Set the last token ID to -1 for all indices not in
+    # seq_indexes_with_bonus_tokens to indicate the lack of a bonus token at
+    # those indices.
+    accepted_token_ids[mask, -1:] = -1
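+    # For illustration (arbitrary values): with k=1, a row [42, 17] would
+    # mean the bonus token 17 was kept, while a row [42, -1] would mean the
+    # sequence received no bonus token this step.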
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector=metrics_collector)
+    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence
+    # IDs. This set includes all sequence IDs in the batch as well as an
+    # additional `num_extra_sequence_ids` sequence IDs. Note that the
+    # sequence IDs are in the range [0, batch_size + num_extra_sequence_ids).
+    num_extra_sequence_ids = 10
+    worker._seq_with_bonus_token_in_last_step = set(
+        range(batch_size + num_extra_sequence_ids))
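+    # The extra IDs stand in for sequences scheduled in other batches; the
+    # update below should leave them untouched.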
+    worker._create_output_sampler_list(
+        seq_group_metadata_list=seq_group_metadata_list,
+        accepted_token_ids=accepted_token_ids,
+        target_logprobs=target_token_logprobs,
+        k=k)
+    # Verify that _seq_with_bonus_token_in_last_step contains the following:
+    # 1. Sequence IDs that were already present in
+    #    _seq_with_bonus_token_in_last_step but were not part of the current
+    #    batch are retained.
+    # 2. Of the sequence IDs present in the current batch, only those with a
+    #    bonus token are retained in _seq_with_bonus_token_in_last_step;
+    #    sequence IDs that are present in the current batch but do not have
+    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
+    expected_seq_ids_with_bonus_tokens = \
+        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
+    additional_sequence_ids = \
+        set(range(batch_size, batch_size + num_extra_sequence_ids))
+    assert worker._seq_with_bonus_token_in_last_step == \
+        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
+    assert worker._request_id_seq_id_mapping == \
+        expected_request_id_seq_ids_mapping
+
+
+@torch.inference_mode()
+def test_handle_finished_requests():
+    """
+    Test to verify that finished request IDs are appropriately processed to
+    update the internal state of the SpecDecodeWorker.
+
+    This test initializes the SpecDecodeWorker with mock data, marks certain
+    requests as finished, and ensures that the corresponding sequence IDs are
+    correctly removed from the internal mappings.
+    """
+    batch_size = 32
+    k = 3
+    draft_worker = mock_worker(cls=MultiStepWorker)
+    target_worker = mock_worker()
+    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+    worker = SpecDecodeWorker(draft_worker, target_worker,
+                              mock_spec_decode_sampler("rejection_sampler"),
+                              metrics_collector)
+    # Initialize the _request_id_seq_id_mapping dict with a few fake request
+    # IDs and their corresponding sequence IDs.
+    worker._request_id_seq_id_mapping = \
+        {'request-1': {1, 2, 3}, 'request-2': {4, 5, 6, 7},
+         'request-3': {8, 9}, 'request-4': {10, 11}}
+    # Initialize _seq_with_bonus_token_in_last_step with a few fake
+    # sequence IDs.
+    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
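+    # Sequence 1 belongs to request-1 and sequences 8 and 9 to request-3, so
+    # finishing those requests should leave only {4, 5, 10} behind.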
+    exception_secret = 'artificial stop'
+    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
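+    # The artificial ValueError aborts execute_model right after the
+    # finished-request bookkeeping runs, so only the cleanup path is
+    # exercised, not a full speculative decoding step.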
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    # Mark the requests with IDs request-1 and request-3 as finished.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        finished_requests_ids=['request-1', 'request-3'])
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+    # Verify that request-1 and request-3 are removed from
+    # _request_id_seq_id_mapping.
+    assert worker._request_id_seq_id_mapping == \
+        {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}
+    # Verify that all sequence IDs corresponding to 'request-1'
+    # and 'request-3' are removed from _seq_with_bonus_token_in_last_step.
+    assert worker._seq_with_bonus_token_in_last_step == \
+        {4, 5, 10}