[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)

Alexander Matveev
2024-07-17 17:30:28 -04:00
committed by GitHub
parent 5f0b9933e6
commit e76466dde2
12 changed files with 568 additions and 130 deletions
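For context: before this change, each speculative decode step rebuilt the model inputs on the CPU via prepare_model_input; this commit instead advances the existing GPU tensors in place with the fused ops.advance_step CUDA op. Below is a rough pure-Python sketch of the per-sequence update that op performs (illustration only, not code from this commit; the real op does all of this in a single kernel launch with no host synchronization):

```python
import torch

def advance_step_sketch(num_queries: int, block_size: int,
                        input_tokens: torch.Tensor,       # (num_seqs, )
                        sampled_token_ids: torch.Tensor,  # (num_queries, 1)
                        input_positions: torch.Tensor,    # (num_seqs, )
                        seq_lens: torch.Tensor,           # (num_seqs, )
                        slot_mapping: torch.Tensor,       # (num_seqs, )
                        block_tables: torch.Tensor) -> None:
    for i in range(num_queries):
        # The token sampled at the previous step becomes the next input.
        input_tokens[i] = sampled_token_ids[i, 0]
        # It is consumed at position seq_len; the sequence grows by one.
        pos = int(seq_lens[i])
        input_positions[i] = pos
        seq_lens[i] = pos + 1
        # Translate the logical position into a physical KV-cache slot
        # through the paged-attention block table.
        slot_mapping[i] = (block_tables[i, pos // block_size] * block_size +
                           pos % block_size)
```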

vllm/spec_decode/draft_model_runner.py

@@ -2,17 +2,22 @@ from typing import List, Optional
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.logger import init_logger
-from vllm.sequence import (IntermediateTensors, SamplerOutput,
-                           SequenceGroupMetadata)
+from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
+                           SamplerOutput)
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                       ModelRunner)
 
+logger = init_logger(__name__)
+
+debug_advance_input = False
+enable_gpu_advance_step = True
+
 
 class TP1DraftModelRunner(ModelRunner):
     """Specialized model runner for speculative decoding draft model.
@@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner):
     we could get rid of most CPU-GPU synchronization and data transfer
     overheads by keeping model input and output tensors on GPU all the time.
 
-    This runner is still under development so there's no performance gain
-    at this moment. Currently we adopt a temporary solution that caches the
-    seq_group_metadata_list for multi-step execution, so that we can
-    leverage existing prepare_model_input to be compatible with the current
-    execution flow, but we plan to remove this cache and avoid calling
-    prepare_model_input in execute_model at all.
-
-    The detail development plan includes:
-    1. Use "update_model_input" to update existing model_input without
-       creating a new one.
-    2. Improve the performance of "update_model_input" with a GPU kernel.
-    3. Support TP > 1 (this requires some designs because we do not expect
+    TODOs:
+    1. Currently supports only flash-attn, add support for other attn_backends.
+    2. Support TP > 1 (this requires some designs because we do not expect
        any broadcasting inside execute_model).
     """
@@ -71,51 +67,156 @@ class TP1DraftModelRunner(ModelRunner):
             return_hidden_states=return_hidden_states,
         )
 
-        # TODO: Remove this cache when we are able to update model_input
-        # directly in advance_step.
-        self.cached_seq_group_metadata_list: Optional[
-            List[SequenceGroupMetadata]] = None
-
-    def prepare_model_input(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithSamplingMetadata:
-        """A temporary solution that caches the seq_group_metadata_list
-        for multi-step execution.
-        TODO: In-place update model_input and remove this function.
-        """
-        self.cached_seq_group_metadata_list = seq_group_metadata_list
-        return super().prepare_model_input(
-            seq_group_metadata_list,
-            finished_requests_ids=finished_requests_ids)
-
-    def update_model_input(
-            self, model_input: ModelInputForGPUWithSamplingMetadata,
-            last_output: SamplerOutput
-    ) -> ModelInputForGPUWithSamplingMetadata:
-        """Prepare the model inputs for the next step.
-        TODO: In-place update model_input instead of calling
-        prepare_model_input.
-        """
-        # Append the output token to the sequence data.
-        assert self.cached_seq_group_metadata_list is not None
-        for seq_group_metadata, sequence_group_outputs in zip(
-                self.cached_seq_group_metadata_list, last_output.outputs):
-            seq_group_metadata.is_prompt = False
-
-            for seq_output in sequence_group_outputs.samples:
-                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
-
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq.append_token_id(token_id, token_logprob.logprob)
-                seq.update_num_computed_tokens(1)
-
-        return self.prepare_model_input(self.cached_seq_group_metadata_list)
+    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
+                                    num_queries):
+        assert isinstance(attn_metadata, FlashAttentionMetadata)
+
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert attn_metadata.use_cuda_graph
+
+        assert attn_metadata.num_prefills == 0
+        assert attn_metadata.num_prefill_tokens == 0
+        assert attn_metadata.num_decode_tokens == num_seqs
+        assert attn_metadata.slot_mapping.shape == (num_seqs, )
+
+        assert len(attn_metadata.seq_lens) == num_seqs
+        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
+        assert attn_metadata.max_query_len == 1
+        assert attn_metadata.max_prefill_seq_len == 0
+        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
+
+        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
+        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert attn_metadata.context_lens_tensor.shape == (num_queries, )
+
+        assert attn_metadata.block_tables.shape[0] == num_seqs
+
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to captured cuda graph batch size
+        for i in range(num_queries):
+            attn_metadata.seq_lens[i] += 1
+        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
+
+    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
+                                  num_queries):
+        assert sampling_metadata.num_prompts == 0
+        assert len(sampling_metadata.seq_groups) == num_queries
+        assert sampling_metadata.selected_token_indices.shape == (
+            num_queries, )
+        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
+
+        # Verify that all sequences are decodes
+        for i in range(num_queries):
+            seq_group = sampling_metadata.seq_groups[i]
+
+            assert seq_group.is_prompt is False  # No prompt
+            assert seq_group.prompt_logprob_indices == []  # No prompt
+            assert seq_group.sample_indices == [i]  # Simple
+            assert seq_group.seq_len is None  # Decode
+            assert seq_group.query_len is None  # Decode
+
+    def _gpu_advance_step(
+            self, model_input: ModelInputForGPUWithSamplingMetadata,
+            last_output: SamplerOutput
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        # Currently, we expect "decode mode" only
+        assert not model_input.is_prompt
+
+        # Get num_seqs
+        num_seqs = len(model_input.seq_lens)
+        num_queries = len(model_input.query_lens)
+
+        # Get output tokens GPU tensor
+        sampled_token_ids = last_output.sampled_token_ids
+        assert sampled_token_ids is not None
+
+        # Update attn_metadata
+        attn_metadata = model_input.attn_metadata
+        assert isinstance(attn_metadata, FlashAttentionMetadata)
+        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
+
+        # Update GPU tensors
+        ops.advance_step(num_seqs=num_seqs,
+                         num_queries=num_queries,
+                         block_size=self.block_size,
+                         input_tokens=model_input.input_tokens,
+                         sampled_token_ids=sampled_token_ids,
+                         input_positions=model_input.input_positions,
+                         seq_lens=attn_metadata.seq_lens_tensor,
+                         slot_mapping=attn_metadata.slot_mapping,
+                         block_tables=attn_metadata.block_tables)
+
+        # Update sampling_metadata
+        sampling_metadata = model_input.sampling_metadata
+        self._update_sampling_metadata(sampling_metadata, num_seqs,
+                                       num_queries)
+
+        # Create new input
+        new_model_input = self._model_input_cls(
+            input_tokens=model_input.input_tokens,
+            input_positions=model_input.input_positions,
+            attn_metadata=attn_metadata,
+            seq_lens=attn_metadata.seq_lens,
+            query_lens=model_input.query_lens,
+            lora_mapping=model_input.lora_mapping,
+            lora_requests=model_input.lora_requests,
+            multi_modal_kwargs=model_input.multi_modal_kwargs,
+            sampling_metadata=model_input.sampling_metadata,
+            is_prompt=False,
+        )
+
+        # Ensure we skip CPU samples
+        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
+
+        # We can reuse sampling tensors since every decode iteration is the same
+        new_model_input.sampling_metadata.reuse_sampling_tensors = True
+
+        if debug_advance_input:
+            logger.debug("NEW INPUT: ")
+            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
+            logger.debug("  input_positions = %s",
+                         new_model_input.input_positions)
+            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
+            logger.debug("  query_lens = %d", new_model_input.query_lens)
+            logger.debug("  attn_metadata:")
+            logger.debug("    seq_lens_tensor: %s",
+                         attn_metadata.seq_lens_tensor)
+            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
+            logger.debug("    block_tables: %s", attn_metadata.block_tables)
+
+        return new_model_input
+
+    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
+        """Determines if draft_model_runner GPU multi-step can be used.
+        Currently required conditions are:
+            1. Only decodes
+            2. Only flash-attn
+            3. No LORA
+            4. No prompt_adapter_config
+        """
+        if not enable_gpu_advance_step:
+            return False
+
+        # We allow multi-step GPU only in decode mode
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            if seq_group.is_prompt:
+                return False
+
+        # TODO: Add support for other attn backends
+        if self.attn_backend.get_name() != "flash-attn":
+            return False
+
+        # TODO: Add support for LORA
+        if self.lora_config:
+            return False
+
+        # TODO: Add soft-tuning prompt adapter support
+        if self.prompt_adapter_config:
+            return False
+
+        return True
 
     @torch.inference_mode()
     def execute_model(
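The num_seqs / num_queries split in _update_flash_attn_metadata above exists because CUDA graph capture pads the batch: padded tensors are sized by num_seqs, while per-query tensors keep the live size. A small worked example with invented values:

```python
# Hypothetical decode batch: a CUDA graph captured for batch size 4
# serving 3 live queries (all numbers invented for illustration).
num_queries = 3
num_seqs = 4
seq_lens = [10, 7, 22, 0]          # len == num_seqs; last entry is padding

# Mirrors the update loop in the hunk above: only live queries advance.
for i in range(num_queries):
    seq_lens[i] += 1

assert seq_lens == [11, 8, 23, 0]  # padding row untouched
assert max(seq_lens) == 23         # the new max_decode_seq_len
```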
@@ -125,42 +226,86 @@ class TP1DraftModelRunner(ModelRunner):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
-        # Since we do not broadcast data inside execute_model anymore,
-        # we need to figure out the best way to support TP > 1 in this
-        # case, because we will at least need to broadcast the sampled
-        # tokens to all workers.
-        if not self.is_driver_worker:
-            raise ValueError("TP1DraftModelRunner only supports TP=1.")
-
-        if self.lora_config:
-            assert model_input.lora_requests is not None
-            assert model_input.lora_mapping is not None
-            self.set_active_loras(model_input.lora_requests,
-                                  model_input.lora_mapping)
-
-        if self.prompt_adapter_config:
-            assert model_input.prompt_adapter_requests is not None
-            assert model_input.prompt_adapter_mapping is not None
-            self.set_active_prompt_adapters(
-                model_input.prompt_adapter_requests,
-                model_input.prompt_adapter_mapping)
-
-        virtual_engine = model_input.virtual_engine
+        """Executes num_steps forward passes with advancement of input
+        tensors on the GPU. Look at supports_gpu_multi_step(..) for
+        pre-conditions.
+
+        Optimizations used:
+            1. Input tensors are updated on the GPU directly
+            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
+               them since we do batch expansion later that uses GPU outputs)
+            3. Reuses sampling tensors (since we run only decodes and they
+               have a repeating sampling logic)
+        """
+
+        # When num_steps == 1, we execute the fallback here for the GPU
+        # advance_step, which runs prepare_inputs on CPU and for each spec
+        # iteration invokes this function only once
+        # (Look at multi-step-worker code)
+        is_fallback = num_steps == 1
+        if not is_fallback:
+            # Since we do not broadcast data inside execute_model anymore,
+            # we need to figure out the best way to support TP > 1 in this
+            # case, because we will at least need to broadcast the sampled
+            # tokens to all workers.
+            if not self.is_driver_worker:
+                raise ValueError("TP1DraftModelRunner only supports TP=1.")
+
+            # Sanity
+            if self.lora_config is not None:
+                raise ValueError("TP1DraftModelRunner has no support for LORA")
+            if self.prompt_adapter_config is not None:
+                raise ValueError("TP1DraftModelRunner has no support for "
+                                 "prompt_adapter_config")
+            if model_input.multi_modal_kwargs:
+                raise ValueError(
+                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
+                )
+        else:
+            if self.lora_config:
+                assert model_input.lora_requests is not None
+                assert model_input.lora_mapping is not None
+                self.set_active_loras(model_input.lora_requests,
+                                      model_input.lora_mapping)
+
+            if self.prompt_adapter_config:
+                assert model_input.prompt_adapter_requests is not None
+                assert model_input.prompt_adapter_mapping is not None
+                self.set_active_prompt_adapters(
+                    model_input.prompt_adapter_requests,
+                    model_input.prompt_adapter_mapping)
+
+        # Detect exec mode
+        assert model_input.attn_metadata is not None
+        use_cuda_graph = False
+        if model_input.attn_metadata.num_prefills > 0:
+            # In this case, execute_model(..) was called directly
+            if num_steps > 1:
+                raise ValueError(
+                    "execute_model(..) of draft_model_runner can be called "
+                    "directly only with a single-step prefill")
+        else:
+            # We can skip CPU samples for spec token generation.
+            # (We do allow CPU samples for num_steps == 1 to support the
+            # fallback case, where supports_gpu_multi_step(..) does not pass)
+            model_input.sampling_metadata.skip_sampler_cpu_output = (
+                not is_fallback)
+
+            # Attn attr defines if we use cuda graphs
+            use_cuda_graph = model_input.attn_metadata.use_cuda_graph
+
+        # Get model
+        if use_cuda_graph:
+            graph_batch_size = model_input.input_tokens.shape[0]
+            model_executable = (self.graph_runners[model_input.virtual_engine]
+                                [graph_batch_size])
+        else:
+            model_executable = self.model
 
         outputs: List[SamplerOutput] = []
         for step in range(num_steps):
-            # Currently cuda graph is only supported by the decode phase.
-            assert model_input.attn_metadata is not None
-            prefill_meta = model_input.attn_metadata.prefill_metadata
-            decode_meta = model_input.attn_metadata.decode_metadata
-            if prefill_meta is None and decode_meta.use_cuda_graph:
-                assert model_input.input_tokens is not None
-                graph_batch_size = model_input.input_tokens.shape[0]
-                model_executable = (
-                    self.graph_runners[virtual_engine][graph_batch_size])
-            else:
-                model_executable = self.model
-
             multi_modal_kwargs = model_input.multi_modal_kwargs or {}
 
+            # Run model
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -181,8 +326,8 @@ class TP1DraftModelRunner(ModelRunner):
                 sampling_metadata=model_input.sampling_metadata,
             ))
 
-            # Prepare the inputs for the next step.
+            # Prepare inputs for the next step
             if step != num_steps - 1:
-                model_input = self.update_model_input(model_input, outputs[-1])
+                model_input = self._gpu_advance_step(model_input, outputs[-1])
 
         return outputs
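From the caller's side, the two modes of execute_model then look roughly as follows (a condensed, hypothetical sketch; the actual wiring lives in MultiStepWorker, shown in the next file):

```python
# Hypothetical driver sketch, assuming `runner` is a TP1DraftModelRunner,
# `k` is the number of speculative steps, and `model_input`/`kv_caches`
# came from one CPU-side prepare_model_input call.
if runner.supports_gpu_multi_step(execute_model_req):
    # GPU multi-step: one call runs all k steps; between steps the inputs
    # are advanced on the GPU by _gpu_advance_step.
    outputs = runner.execute_model(model_input, kv_caches, num_steps=k)
else:
    # Fallback: one call per step (num_steps == 1); the caller re-runs
    # CPU input preparation between calls.
    outputs = []
    for _ in range(k):
        outputs.extend(
            runner.execute_model(model_input, kv_caches, num_steps=1))
        # ...re-prepare model_input on the CPU before the next step...
```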

vllm/spec_decode/multi_step_worker.py

@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         expanded_request, indices_of_seq_with_bonus_tokens =\
             self._expand_execute_model_request(
                 execute_model_req, seq_ids_with_bonus_token_in_last_step)
 
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        if isinstance(self.model_runner, TP1DraftModelRunner):
+        if isinstance(
+                self.model_runner, TP1DraftModelRunner
+        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
+            # Here we run the draft_model_runner with multi-step prepare
+            # on the GPU directly
             expanded_request.num_steps = sample_len
             model_outputs = self.execute_model(
                 execute_model_req=expanded_request)
         else:
-            # TODO: Remove this branch once DraftModelRunner supports TP>1.
+            # Here we run multi-step directly, with every step prepared
+            # on the CPU.
+            # TODO: Remove this branch once DraftModelRunner supports TP>1
+            # and other restrictions that are part of DraftModelRunner's
+            # supports_gpu_multi_step(..)
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = super().execute_model(
                     execute_model_req=expanded_request)
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
                 outputs=[
                     expanded_batch_output.outputs[i]
                     for i in output_indices_to_retain
-                ],
+                ] if len(expanded_batch_output.outputs) > 0 else [],
                 sampled_token_probs=(
                     expanded_batch_output.
                     sampled_token_probs[output_indices_to_retain]
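The guard added in this last hunk pairs with optimization 2 in execute_model: when skip_sampler_cpu_output is set, the draft runner returns SamplerOutputs whose .outputs list is empty (the samples stay on the GPU), and indexing that empty list by the retained positions would raise IndexError. A minimal illustration:

```python
outputs = []                       # CPU-side sample lists were skipped
output_indices_to_retain = [0, 2]

# Without the guard, outputs[0] would raise IndexError here.
filtered = ([outputs[i] for i in output_indices_to_retain]
            if len(outputs) > 0 else [])
assert filtered == []
```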