[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)
tests/spec_decode/utils.py (new file, 257 lines)
@@ -0,0 +1,257 @@
from dataclasses import dataclass, fields
from itertools import count
from typing import Dict, List, Optional, Union
from unittest.mock import MagicMock

import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, SequenceData,
                           SequenceGroupMetadata, SequenceGroupOutput,
                           SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker


@dataclass
class ExecuteModelData:
    """Helper data structure which facilitates cleaner tests."""

    seq_group_metadata_list: List[SequenceGroupMetadata]
    blocks_to_swap_in: Dict[int, int]
    blocks_to_swap_out: Dict[int, int]
    blocks_to_copy: Dict[int, List[int]]

    def to_dict(self):
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))

    @classmethod
    def from_dict(cls, d):
        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
        return cls(**cleaned)
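# Example usage (hypothetical sketch): the dict round-trip is lossless, so a
# test could unpack the struct directly into a call such as
# worker.execute_model(**data.to_dict()).
#
#     data = create_execute_model_data(seq_group_metadata_list=[])
#     assert ExecuteModelData.from_dict(data.to_dict()) == data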
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    # Despite its name, this returns the number of blocks needed to hold
    # `seq_len` tokens, i.e. the sequence length rounded up to a block
    # boundary and divided by the block size.
    return (seq_len + block_size - 1) // block_size
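# For example, with block_size=8 a sequence of 1-8 tokens needs 1 block and a
# sequence of 9 tokens needs 2.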
def create_execute_model_data(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> ExecuteModelData:
    # Default each block mapping to an empty dict so callers can omit them.
    if blocks_to_swap_in is None:
        blocks_to_swap_in = {}
    if blocks_to_swap_out is None:
        blocks_to_swap_out = {}
    if blocks_to_copy is None:
        blocks_to_copy = {}

    return ExecuteModelData(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )
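# Example usage (hypothetical sketch): omitted block mappings default to
# empty dicts, so most tests only pass the sequence metadata.
#
#     data = create_execute_model_data(seq_group_metadata_list)
#     assert data.blocks_to_swap_in == {}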
def mock_worker(cls=None,
                vocab_size: int = 30_000,
                max_model_len: int = 2048,
                rank: int = 0) -> MagicMock:
    if cls is None:
        cls = Worker

    worker = MagicMock(spec=cls)
    worker.vocab_size = vocab_size
    worker.max_model_len = max_model_len
    worker.rank = rank
    worker.device = 'cuda:0'
    return worker
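# Example usage (hypothetical sketch): the mock respects the Worker spec, so
# tests can stub methods without touching a GPU.
#
#     worker = mock_worker(vocab_size=32_000)
#     assert worker.vocab_size == 32_000
#     worker.execute_model.return_value = []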
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    """Return a wrapper around worker.execute_model which reseeds the RNG
    after each forward pass. The caller is responsible for assigning the
    wrapper back onto the worker.
    """
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        result = original_execute_model(*args, **kwargs)
        set_random_seed(next(seed_iter))
        return result

    return new_execute_model
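# Example usage (hypothetical sketch): the wrapper must be assigned back onto
# the worker by the caller, e.g.
#
#     worker.execute_model = patch_execute_model_with_seeds(
#         worker, rand_seeds=[0, 1, 2])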
def zero_kv_cache(cache_engine: CacheEngine):
    assert cache_engine.gpu_cache
    for key_blocks, value_blocks in cache_engine.gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()
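# Example usage (hypothetical sketch): zeroing the KV cache between runs lets
# two workers be compared starting from identical cache state.
#
#     zero_kv_cache(worker.cache_engine)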
def create_worker(cls: type,
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True):
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
    )

    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, _) = engine_args.create_engine_configs()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    worker = cls(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
    )

    worker.init_model()
    worker.load_model()

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    return worker
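# Example usage (hypothetical sketch; assumes a CUDA device and that the
# 'facebook/opt-125m' checkpoint is available):
#
#     worker = create_worker(Worker, 'facebook/opt-125m', block_size=32,
#                            num_gpu_blocks=128, seed=0)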
def create_seq_group_metadata_from_prompts(
    prompts: List[List[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_seq_lens: List[int],
    continuations: Optional[List[List[int]]] = None,
    seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:

    if continuations is None:
        continuations = [[] for _ in prompts]

    if seq_ids is None:
        seq_ids = list(range(len(prompts)))

    # Pre-allocate enough physical blocks for each sequence to grow to its
    # final length.
    free_gpu_blocks = list(range(num_gpu_blocks))

    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_seq_lens)
    }

    # Key each sequence by its caller-provided seq id; block allocations stay
    # keyed by batch position.
    return [
        SequenceGroupMetadata(
            request_id=str(seq_id),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                seq_id:
                SequenceData(
                    prompt_token_ids=prompt_token_ids[:],
                    output_token_ids=cont_token_ids[:],
                ),
            },
            sampling_params=SamplingParams(temperature=0.0),
            block_tables={seq_id: block_allocations[i][:]},
        ) for i, (seq_id, prompt_token_ids, cont_token_ids) in enumerate(
            zip(seq_ids, prompts, continuations))
    ]
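# Example usage (hypothetical sketch): two 3-token prompts, each allowed to
# grow to 8 tokens (one block at block_size=8), with no prior output:
#
#     metadata = create_seq_group_metadata_from_prompts(
#         prompts=[[0, 1, 2], [3, 4, 5]],
#         num_gpu_blocks=4,
#         block_size=8,
#         final_seq_lens=[8, 8])
#     assert all(m.is_prompt for m in metadata)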
def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, Logprob]],
        expected_logprobs: List[Dict[int, Logprob]]) -> None:
    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(
                single_step_actual_logprobs[token_id].logprob)
            expected = torch.tensor(
                single_step_expected_logprobs[token_id].logprob)
            assert torch.allclose(actual, expected)
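# Example usage (hypothetical sketch): passes, since both sides hold the same
# logprob for token id 5.
#
#     assert_logprobs_dict_allclose([{5: Logprob(-0.1)}],
#                                   [{5: Logprob(-0.1)}])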
def create_sampler_output_list(
        token_ids: torch.Tensor,
        probs: List[Optional[torch.Tensor]],
        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
    num_steps, batch_size = token_ids.shape
    token_ids_by_step = token_ids.tolist()

    if seq_ids is None:
        seq_ids = list(range(batch_size))

    return [
        SamplerOutput(
            outputs=[
                SequenceGroupOutput(
                    samples=[
                        SequenceOutput(
                            output_token=token_id,
                            parent_seq_id=seq_ids[seq_index],
                            logprobs={token_id: Logprob(0)},
                        )
                    ],
                    prompt_logprobs=None,
                ) for seq_index, token_id in enumerate(
                    token_ids_by_step[step])
            ],
            sampled_token_probs=probs[step],
            sampled_token_ids=token_ids[step],
        ) for step in range(num_steps)
    ]
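# Example usage (hypothetical sketch): fabricate three decode steps of token
# ids for a batch of two sequences, with no probability tensors attached.
#
#     token_ids = torch.randint(low=0, high=100, size=(3, 2))
#     outputs = create_sampler_output_list(token_ids, probs=[None] * 3)
#     assert len(outputs) == 3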
def create_batch(batch_size: int,
                 k: int,
                 prompt_len: Union[int, List[int]] = 10,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[List[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None):
    if block_size is None:
        block_size = 8

    if num_gpu_blocks is None:
        num_gpu_blocks = 2048 // block_size

    # Draw token ids from a shared counter so every token in the batch is
    # unique.
    iterator = count()

    if isinstance(prompt_len, int):
        prompt_lens = [prompt_len for _ in range(batch_size)]
    else:
        prompt_lens = prompt_len

    prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
    prev_output_tokens = [[
        next(iterator) for _ in range(prev_output_token_len)
    ] for _ in range(batch_size)]
    # Leave room for k speculative tokens plus one bonus token per sequence.
    final_seq_lens = [
        len(prompt) + len(prev_output_token) + k + 1
        for prompt, prev_output_token in zip(prompts, prev_output_tokens)
    ]

    execute_model_data = create_execute_model_data(
        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
                                               block_size, final_seq_lens,
                                               prev_output_tokens, seq_ids))
    return execute_model_data, prompts, prev_output_tokens
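# Example usage (hypothetical sketch): a batch of four sequences, each with a
# 10-token prompt, 10 previous output tokens, and room for k=5 speculative
# tokens plus one bonus token.
#
#     execute_model_data, prompts, prev_output_tokens = create_batch(
#         batch_size=4, k=5)
#     assert len(prompts) == 4 and all(len(p) == 10 for p in prompts)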