[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (#9730)
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
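For context, here is a minimal, self-contained sketch of the failure mode this fix targets (the helper names, shapes, and values below are illustrative only and are not vLLM APIs): in speculative decoding the target model can accept all draft tokens and additionally sample one "bonus" token, which the draft model itself never ran a forward pass on, so the draft model's KV cache is missing that token's entries until they are explicitly written.

# Illustrative sketch only; hypothetical names, not vLLM code.
import torch

num_layers, max_len, head_dim = 2, 8, 4
# Toy per-layer KV cache for the draft model, zero-initialized.
kv_cache = [torch.zeros(max_len, head_dim) for _ in range(num_layers)]

def write_kv(position, value):
    # Write a token's KV entries at its position in every layer.
    for layer in kv_cache:
        layer[position] = value

accepted = [5, 17, 42]   # draft-proposed tokens accepted by the target model
bonus_token = 99         # extra token sampled by the target model
for pos, _ in enumerate(accepted):
    write_kv(pos, torch.ones(head_dim))

# If the bonus token's slot is never written, the next proposal step reads an
# uninitialized row at position len(accepted): the "KV corruption" in the title.
assert torch.all(kv_cache[0][len(accepted)] == 0)
# Writing the bonus token's entries before proposing again avoids the problem.
write_kv(len(accepted), torch.ones(head_dim))
assert torch.all(kv_cache[0][len(accepted)] == 1)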
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock
import pytest
import torch

from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    multi_step_worker.set_include_gpu_probs_tensor()
    worker = create_worker(
        Worker,
        model_name,
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )
    multi_step_worker.set_include_gpu_probs_tensor()
    worker = create_worker(
        Worker,
        model_name,
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
    assert (num_mismatch > 0)


@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner, so that both code
# paths are tested.
@pytest.mark.parametrize('attn_backend',
                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
    """Verify that the KV cache of the draft model
    is correctly updated for sequences with a bonus token.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 1

    with global_force_attn_backend_context_manager(attn_backend):
        dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
        multi_step_worker = create_worker(MultiStepWorker,
                                          model_name,
                                          block_size,
                                          num_gpu_blocks,
                                          seed,
                                          model_runner_cls=TP1DraftModelRunner,
                                          dtype=dtype)
        multi_step_worker.set_include_gpu_probs_tensor()
        worker = create_worker(Worker,
                               model_name,
                               block_size,
                               num_gpu_blocks,
                               seed,
                               dtype=dtype)

        prompts = [[0] for _ in range(batch_size)]
        # Two tokens have already been generated for each sequence,
        # so that the bonus-token case can be simulated.
        multi_step_continuations = [[
            random.randint(0, 1000),
            random.randint(0, 1000)
        ] for _ in prompts]
        final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]

        seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=multi_step_continuations,
            final_prompt_lens=final_prompt_lens)

        # Run multi-step.
        zero_kv_cache(multi_step_worker.cache_engine)
        multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
                                         sample_len=num_steps,
                                         seq_ids_with_bonus_token_in_last_step=
                                         seq_ids_with_bonus_token_in_last_step)
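        # The call above marks every sequence in
        # `seq_ids_with_bonus_token_in_last_step` as having ended the previous
        # step with a bonus token sampled by the target model, i.e. a token
        # whose KV entries the draft model has not yet written; this is the
        # scenario in which the KV corruption fixed by this commit occurred.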

        # Run single-step repeatedly.
        zero_kv_cache(worker.cache_engine)
        # Generate the KV cache for the bonus token first.
        single_step_continuations = [c[:1] for c in multi_step_continuations]
        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=single_step_continuations,
            final_prompt_lens=final_prompt_lens)
        single_step_output = worker.execute_model(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list))
        for _ in range(num_steps):
            seq_group_metadata_list = create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=multi_step_continuations,
                final_prompt_lens=final_prompt_lens)

            single_step_output = worker.execute_model(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list))

            for i, seq_group_output in enumerate(single_step_output[-1]):
                multi_step_continuations[i].append(
                    seq_group_output.samples[0].output_token)

        # Verify that the KV caches of the single-step and
        # multi-step workers are the same.
        single_step_gpu_cache = worker.cache_engine[0].gpu_cache
        multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
        num_layers = len(single_step_gpu_cache)
        allclose = lambda a, b: torch.allclose(
            a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
        for i in range(num_layers):
            assert allclose(single_step_gpu_cache[i][0],
                            multi_step_gpu_cache[i][0])
            assert allclose(single_step_gpu_cache[i][1],
                            multi_step_gpu_cache[i][1])


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify Top1Proposer correctly handles case where all sequences