[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
@@ -118,7 +118,8 @@ def test_same_output_for_single_step():
     actual_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=multi_step_seq_group),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
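The new `seq_ids_with_bonus_token_in_last_step` argument tells the draft worker which sequences accepted all of their proposed tokens in the previous step and therefore carry a bonus token that the draft model's KV cache has not yet processed. The tests in these hunks pass an empty set because they exercise plain multi-step sampling. As a rough sketch of how a caller might build this set, with hypothetical names (`accepted_counts`, `proposal_len`) that are not part of the vLLM API:

```python
from typing import Dict, Set


def seq_ids_with_bonus_tokens(accepted_counts: Dict[int, int],
                              proposal_len: int) -> Set[int]:
    """Hypothetical helper: a sequence receives a bonus token exactly
    when the target model accepts every one of its proposed tokens."""
    return {
        seq_id
        for seq_id, num_accepted in accepted_counts.items()
        if num_accepted == proposal_len
    }


# With k=3 draft tokens per step: sequence 0 accepted all three and got a
# bonus token; sequence 1 was rejected early, so it gets no bonus token.
assert seq_ids_with_bonus_tokens({0: 3, 1: 1}, proposal_len=3) == {0}
```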
@@ -210,7 +211,8 @@ def test_same_output_for_multi_step():
     multi_step_output, _ = multi_step_worker.sampler_output(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list),
-        sample_len=num_steps)
+        sample_len=num_steps,
+        seq_ids_with_bonus_token_in_last_step=set())

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -277,6 +279,203 @@ def test_same_output_for_multi_step():
                                       single_step_logprobs)


+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_correct_output():
+    """
+    Verify that the MultiStepWorker handles bonus tokens correctly: if a
+    sequence has a bonus token, the worker expands the batch by adding new
+    sequences corresponding to the sequences with bonus tokens, and the
+    expanded batch is then used to predict the next tokens.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations.
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run single-step twice to generate 2 tokens. This simulates the bonus
+    # token case, with the second token being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to the new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step and verify that the third-token prediction is accurate
+    # for all sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    all_seq_ids = {i for i in range(batch_size)}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=all_seq_ids)
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        assert (continuations[index][-1] == output.samples[0].output_token)
+
+
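The batch expansion this test exercises can be pictured with a small standalone sketch. This is a simplified illustration of the idea, not the MultiStepWorker implementation: for each sequence flagged as having a bonus token, a shadow copy that excludes the bonus token is added to the batch so the draft model can backfill the missing KV-cache entries.

```python
from typing import Dict, List, Set


def expand_batch(token_ids: Dict[int, List[int]],
                 bonus_seq_ids: Set[int]) -> Dict[int, List[int]]:
    """Simplified sketch: add a shadow sequence (negative id, purely for
    illustration) holding everything except the final bonus token."""
    expanded = dict(token_ids)
    for seq_id in bonus_seq_ids:
        expanded[-seq_id - 1] = token_ids[seq_id][:-1]
    return expanded


# Sequence 0 ends in a bonus token, so the expanded batch gains an extra
# sequence containing all of its tokens except that final one.
batch = expand_batch({0: [5, 9, 42], 1: [7, 8]}, bonus_seq_ids={0})
assert batch[-1] == [5, 9]
```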
+@torch.inference_mode()
+def test_multi_step_with_batch_expansion_incorrect_output():
+    """
+    Test the MultiStepWorker's handling of batch expansion with bonus
+    tokens in a negative case. This test gives the MultiStepWorker a batch
+    containing sequences with bonus tokens, but specifies the set of
+    sequence IDs with bonus tokens incorrectly. It verifies that the
+    MultiStepWorker generates correct tokens for the sequences whose IDs
+    are specified correctly and incorrect tokens for those whose IDs are
+    specified incorrectly.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 128
+    multi_step_worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+    worker = create_worker(
+        Worker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+    )
+    random.seed(seed)
+    prompts = [[0] for _ in range(batch_size)]
+    num_steps = 2
+    final_prompt_lens = [(num_steps + 1) for prompt in prompts]
+    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
+    multi_step_worker.execute_model = patch_execute_model_with_seeds(
+        multi_step_worker, rand_seeds)
+    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
+    # Create the test continuations.
+    continuations = [[random.randint(0, 1000)] for _ in prompts]
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
+    # Run single-step twice to generate 2 tokens. This simulates the bonus
+    # token case, with the second token being the bonus token.
+    zero_kv_cache(worker.cache_engine)
+    single_step_output: List[SamplerOutput] = []
+    set_random_seed(seed)
+    for _ in range(num_steps):
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output.extend(
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
+        # Append output tokens to the new sequence data.
+        for i, seq_group_output in enumerate(single_step_output[-1]):
+            continuations[i].append(seq_group_output.samples[0].output_token)
+
+    # Create continuations for the MultiStepWorker. The continuations have
+    # 2 tokens in order to simulate the bonus token case.
+    multi_step_continuations = []
+    for continuation in continuations:
+        multi_step_continuations.append(continuation[:2])
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=multi_step_continuations,
+        final_prompt_lens=final_prompt_lens)
+
+    # Run multi-step. In this run, INCORRECTLY specify that only the
+    # odd-numbered sequences have bonus tokens. Verify that with this
+    # setting the third-token prediction is accurate only for the
+    # odd-numbered sequences, and that the prediction may be wrong for some
+    # of the even-numbered sequences.
+    zero_kv_cache(multi_step_worker.cache_engine)
+    set_random_seed(seed)
+    odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=1,
+        seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
+    num_mismatch = 0
+    for index, output in enumerate(multi_step_output[-1].outputs):
+        if (index % 2) != 0:
+            assert (continuations[index][-1] == output.samples[0].output_token)
+        elif (continuations[index][-1] != output.samples[0].output_token):
+            num_mismatch += 1
+    # The prediction is accurate for some of the sequences even without
+    # proper handling of the bonus tokens, so verify only that the number
+    # of mismatched sequences is > 0.
+    assert (num_mismatch > 0)
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
@@ -318,7 +517,8 @@ def test_draft_proposals_full_speculation_len():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -356,7 +556,8 @@ def test_draft_proposals_no_speculations():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -428,7 +629,8 @@ def test_draft_proposals_mixed_k():
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
-            num_lookahead_slots=k), )
+            num_lookahead_slots=k),
+        seq_ids_with_bonus_token_in_last_step=set())

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
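The three proposer hunks above extend `Top1Proposer.get_spec_proposals` to accept the same set; the tests pass `set()` because they do not involve bonus tokens. A hedged sketch of a call site with a non-empty set; the wrapper name and the specific ids are illustrative, and only the `seq_ids_with_bonus_token_in_last_step` keyword comes from this change:

```python
from typing import Set

from vllm.sequence import ExecuteModelRequest


def propose_with_bonus(proposer, seq_group_metadata_list, k: int,
                       bonus_seq_ids: Set[int]):
    """Hypothetical wrapper: forward the bonus-token ids alongside the
    usual speculative-decoding proposal request."""
    return proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k),
        seq_ids_with_bonus_token_in_last_step=bonus_seq_ids)
```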