[Speculative decoding 2/9] Multi-step worker for draft model (#2424)
This commit is contained in:
@@ -18,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
|
||||
|
||||
if ray:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
@@ -132,7 +132,8 @@ class LLMEngine:
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
@@ -207,7 +208,8 @@ class LLMEngine:
|
||||
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
||||
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
||||
|
||||
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
|
||||
# get_open_port must be *called* here: the init-method URL needs a concrete
# port number, not the function object itself.
distributed_init_method = get_distributed_init_method(
    driver_ip, get_open_port())
|
||||
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
|
||||
@@ -65,10 +65,9 @@ def initialize_cluster(
|
||||
the default Ray cluster address.
|
||||
|
||||
Returns:
|
||||
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||
`distributed_init_method` is the address for initializing the
|
||||
distributed backend. `placement_group` includes the specification
|
||||
of the resources for each distributed worker.
|
||||
An optional `PlacementGroup`. It includes the specification
|
||||
of the resources for each distributed worker. None if Ray is
|
||||
not used.
|
||||
"""
|
||||
if parallel_config.worker_use_ray or engine_use_ray:
|
||||
if ray is None:
|
||||
|
||||
@@ -83,6 +83,31 @@ def initialize_model_parallel(
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> None:
    """Create the model parallel groups on first use, otherwise validate them.

    If the groups have not been initialized yet, they are initialized with the
    requested sizes. If they already exist, the tensor-parallel and
    pipeline-parallel world sizes are asserted to equal the requested values.
    """
    if model_parallel_is_initialized():
        # Groups already exist: they must match what the caller expects.
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        assert (
            get_pipeline_model_parallel_world_size() == pipeline_model_parallel_size
        ), ("pipeline parallel group already initialized, but of unexpected size: "
            f"{get_pipeline_model_parallel_world_size()=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        # Nothing set up yet: build the groups with the requested sizes.
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size)
|
||||
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if tensor and pipeline parallel groups are initialized."""
|
||||
return (_TENSOR_MODEL_PARALLEL_GROUP is not None
|
||||
|
||||
@@ -65,6 +65,10 @@ def get_ip() -> str:
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build the ``tcp://`` address used to initialize the distributed backend.

    Args:
        ip: Host name or IP address of the rendezvous host. An IPv6 literal
            may be given with or without surrounding brackets.
        port: TCP port of the rendezvous host.

    Returns:
        A string of the form ``tcp://<ip>:<port>``. IPv6 literals are wrapped
        in brackets; IPv4 addresses and host names are emitted unchanged.
    """
    # An IPv6 literal contains colons, so it must be bracketed to keep the
    # host/port separator unambiguous. IPv4 addresses and host names never
    # contain a colon, so their output is identical to before.
    if ":" in ip and not ip.startswith("["):
        ip = f"[{ip}]"
    return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
|
||||
@@ -277,8 +277,8 @@ class ModelRunner:
|
||||
input_block_tables[i, :len(block_table)] = block_table
|
||||
block_tables = torch.tensor(input_block_tables, device="cuda")
|
||||
else:
|
||||
max_block_table_len = (max_context_len + self.block_size -
|
||||
1) // self.block_size
|
||||
max_block_table_len = max(
|
||||
len(block_table) for block_table in block_tables)
|
||||
block_tables = _make_tensor_with_pad(
|
||||
block_tables,
|
||||
max_len=max_block_table_len,
|
||||
|
||||
178
vllm/worker/spec_decode/multi_step_worker.py
Normal file
178
vllm/worker/spec_decode/multi_step_worker.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from typing import List, Dict
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups to run.
                The caller's objects are not modified; the forward passes
                operate on shallow copies.
            blocks_to_swap_in: Must be empty (cache operations unsupported).
            blocks_to_swap_out: Must be empty (cache operations unsupported).
            blocks_to_copy: Must be empty (cache operations unsupported).
            num_steps: Number of forward passes to run.

        Returns:
            The list of sampler outputs, one per model forward pass.

        Raises:
            NotImplementedError: If a cache operation is requested or a
                sequence group does not contain exactly one sequence.
            ValueError: If a sequence does not have enough allocated KV slots
                for num_steps additional tokens.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times, feeding each step's sampled tokens back
        # into the copied sequences before the next step.
        model_outputs: List[SamplerOutput] = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.

        Args:
            model_output: Output of one forward pass; iterated in lockstep
                with seq_group_metadata_list.
            seq_group_metadata_list: Sequence group metadata to update in
                place (is_prompt is cleared and sampled tokens are appended).
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            # Once a token has been appended, the group is no longer a prompt.
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        The multi-step worker must be able to append tokens to sequences after
        a forward pass. This necessitates modification of the data structures
        used by the worker. Since these data structures are shared with other
        parts of vLLM, like the scheduler, we must take care not to introduce
        unexpected side-effects.

        When Ray is used to orchestrate worker processes (such as when the
        tensor-parallel degree is >1), this is not a problem because the input
        datastructures will be serialized and created anew in the worker
        process.

        However, when Ray is not used to orchestrate the worker processes (such
        as when the tensor-parallel degree is 1), this is a problem. We avoid
        the problem by shallow-copying the input datastructures (specifically,
        the parts that will change in multiple steps).

        Returns:
            A new list holding a shallow copy of every SequenceGroupMetadata,
            each with its own seq_data dict, per-sequence data copies and
            output_token_ids lists.
        """
        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids.
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                seq_data_copy = copy.copy(old_seq_data)
                # The token list itself is mutated by appends, so it needs its
                # own copy rather than sharing the original list object.
                seq_data_copy.output_token_ids = (
                    old_seq_data.output_token_ids[:])
                new_seq_data[seq_id] = seq_data_copy

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.

        Raises:
            ValueError: If any sequence has fewer allocated KV slots than
                required after num_steps additional tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data)[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated
            # blocks times the number of slots per block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: If any of the block mappings is non-empty, or
                if any sequence group does not hold exactly one sequence.
        """
        if blocks_to_swap_in or blocks_to_swap_out or blocks_to_copy:
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        # More than one sequence per group is treated as beam search.
        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
|
||||
@@ -11,7 +11,7 @@ from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel)
|
||||
ensure_model_parallel_initialized)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
@@ -227,8 +227,8 @@ def _init_distributed_environment(
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
initialize_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
|
||||
Reference in New Issue
Block a user