[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)
Parent: 5f0b9933e6
Commit: e76466dde2
@@ -2,17 +2,22 @@ from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                      ModelRunner)

logger = init_logger(__name__)

debug_advance_input = False
enable_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for speculative decoding draft model.
@@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner):
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    This runner is still under development so there's no performance gain
    at this moment. Currently we adopt a temporary solution that caches the
    seq_group_metadata_list for multi-step execution, so that we can
    leverage existing prepare_model_input to be compatible with the current
    execution flow, but we plan to remove this cache and avoid calling
    prepare_model_input in execute_model at all.

    The detail development plan includes:
    1. Use "update_model_input" to update existing model_input without
       creating a new one.
    2. Improve the performance of "update_model_input" with a GPU kernel.
    3. Support TP > 1 (this requires some designs because we do not expect
    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """

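Note (illustration): the docstring above is the crux of the change. The draft model runs k single-token forward passes back to back, so any CPU round trip between steps is pure overhead. A toy, self-contained Python sketch of that loop shape follows; the embedding/argmax "model" here is a stand-in for illustration, not a vLLM component.

    # Toy sketch of k consecutive draft steps. The point is the loop shape:
    # each step feeds the previous step's sampled token straight back in,
    # and nothing needs to leave the GPU between steps.
    import torch

    torch.manual_seed(0)
    vocab, hidden, k = 100, 16, 4
    embed = torch.randn(vocab, hidden)
    lm_head = torch.randn(hidden, vocab)

    token = torch.tensor(1)           # last token accepted by the target model
    draft_tokens = []
    for _ in range(k):
        logits = embed[token] @ lm_head
        token = torch.argmax(logits)  # kept as a tensor; no .item()/.cpu()
        draft_tokens.append(token)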
@@ -71,51 +67,156 @@ class TP1DraftModelRunner(ModelRunner):
            return_hidden_states=return_hidden_states,
        )

        # TODO: Remove this cache when we are able to update model_input
        # directly in advance_step.
        self.cached_seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                    num_queries):
        assert isinstance(attn_metadata, FlashAttentionMetadata)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """A temporary solution that caches the seq_group_metadata_list
        for multi-step execution.
        TODO: In-place update model_input and remove this function.
        """
        self.cached_seq_group_metadata_list = seq_group_metadata_list
        return super().prepare_model_input(
            seq_group_metadata_list,
            finished_requests_ids=finished_requests_ids)
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert attn_metadata.use_cuda_graph

    def update_model_input(
        assert attn_metadata.num_prefills == 0
        assert attn_metadata.num_prefill_tokens == 0
        assert attn_metadata.num_decode_tokens == num_seqs
        assert attn_metadata.slot_mapping.shape == (num_seqs, )

        assert len(attn_metadata.seq_lens) == num_seqs
        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
        assert attn_metadata.max_query_len == 1
        assert attn_metadata.max_prefill_seq_len == 0
        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)

        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )

        assert attn_metadata.context_lens_tensor.shape == (num_queries, )

        assert attn_metadata.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            attn_metadata.seq_lens[i] += 1
        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model inputs for the next step.
        TODO: In-place update model_input instead of calling
        prepare_model_input.
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)

        # Update GPU tensors
        ops.advance_step(num_seqs=num_seqs,
                         num_queries=num_queries,
                         block_size=self.block_size,
                         input_tokens=model_input.input_tokens,
                         sampled_token_ids=sampled_token_ids,
                         input_positions=model_input.input_positions,
                         seq_lens=attn_metadata.seq_lens_tensor,
                         slot_mapping=attn_metadata.slot_mapping,
                         block_tables=attn_metadata.block_tables)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
            logger.debug("  query_lens = %d", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.
        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not enable_gpu_advance_step:
            return False

        # Append the output token to the sequence data.
        assert self.cached_seq_group_metadata_list is not None
        for seq_group_metadata, sequence_group_outputs in zip(
                self.cached_seq_group_metadata_list, last_output.outputs):
            seq_group_metadata.is_prompt = False
        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

            for seq_output in sequence_group_outputs.samples:
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() != "flash-attn":
            return False

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]
        # TODO: Add support for LORA
        if self.lora_config:
            return False

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)
        # TODO: Add soft-tuning prompt adapter support
        if self.prompt_adapter_config:
            return False

        return self.prepare_model_input(self.cached_seq_group_metadata_list)
        return True

    @torch.inference_mode()
    def execute_model(
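Note (illustration): ops.advance_step is the fused op at the center of _gpu_advance_step above. Inferred from the arguments passed to it, a rough pure-PyTorch sketch of the per-query update it performs is shown below. This is not the vLLM implementation; the real op is a single CUDA kernel and also handles CUDA-graph padding of the non-query slots.

    # Rough pure-PyTorch equivalent of the per-query update done by the fused
    # advance_step kernel (illustration only; not the vLLM implementation).
    # Assumes sampled_token_ids holds one newly sampled token per query.
    import torch

    def advance_step_sketch(num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor,
                            slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor) -> None:
        for i in range(num_queries):
            # The token sampled last step becomes the next step's input token.
            input_tokens[i] = sampled_token_ids[i]
            # The sequence grew by one token; its next position is seq_len - 1.
            seq_len = int(seq_lens[i]) + 1
            seq_lens[i] = seq_len
            next_pos = seq_len - 1
            input_positions[i] = next_pos
            # Translate the logical position into a physical KV-cache slot
            # through the sequence's block table (paged-attention layout).
            block_number = int(block_tables[i, next_pos // block_size])
            slot_mapping[i] = block_number * block_size + next_pos % block_size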
@@ -125,42 +226,86 @@ class TP1DraftModelRunner(ModelRunner):
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")
        """Executes num_steps forward passes with advancement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
               them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
               a repeating sampling logic)
        """

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)
        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
        else:
            model_executable = self.model

        virtual_engine = model_input.virtual_engine
        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            # Currently cuda graph is only supported by the decode phase.
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[virtual_engine][graph_batch_size])
            else:
                model_executable = self.model

            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # Run model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
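Note (illustration): optimization 2 in the docstring above (skipping GPU=>CPU serialization of sampler outputs) matters because pulling the sampled ids to the host forces a device synchronization on every draft step. A minimal, self-contained sketch of the difference, with made-up shapes:

    # The GPU-resident handoff used by the multi-step path queues work without
    # waiting; the .cpu() handoff used by the single-step fallback blocks the
    # host until all queued GPU work has finished.
    import torch

    if torch.cuda.is_available():
        sampled = torch.randint(0, 32000, (256, 1), device="cuda")
        next_tokens = sampled + 0                   # stays on GPU: no host sync
        next_tokens_list = sampled.cpu().tolist()   # forces a synchronization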
@@ -181,8 +326,8 @@ class TP1DraftModelRunner(ModelRunner):
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare the inputs for the next step.
            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self.update_model_input(model_input, outputs[-1])
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs

@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
        expanded_request, indices_of_seq_with_bonus_tokens =\
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if isinstance(self.model_runner, TP1DraftModelRunner):
        if isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly
            expanded_request.num_steps = sample_len
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # TODO: Remove this branch once DraftModelRunner supports TP>1.
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = super().execute_model(
                    execute_model_req=expanded_request)
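Note (illustration): the net effect of the dispatch above is how many execute_model calls the worker issues for sample_len draft tokens. A self-contained sketch of that decision (the helper below is a stand-in for illustration, not a vLLM API):

    # One call with expanded_request.num_steps = sample_len on the GPU
    # multi-step path, versus sample_len single-step calls on the fallback.
    def num_execute_model_calls(sample_len: int, gpu_multi_step: bool) -> int:
        return 1 if gpu_multi_step else sample_len

    assert num_execute_model_calls(5, gpu_multi_step=True) == 1
    assert num_execute_model_calls(5, gpu_multi_step=False) == 5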
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ],
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[output_indices_to_retain]
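Note (illustration): this last MultiStepWorker hunk keeps only the outputs of sequences that should be retained; the probability tensors are filtered with plain row indexing on the GPU. A self-contained sketch with made-up shapes:

    # Row selection of the kind used for sampled_token_probs above.
    import torch

    sampled_token_probs = torch.rand(8, 32000)          # one row per expanded seq
    output_indices_to_retain = torch.tensor([0, 2, 5])  # sequences to keep
    filtered = sampled_token_probs[output_indices_to_retain]
    assert filtered.shape == (3, 32000)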