[Model Runner V2] Rebuild attention metadata before eagle decode full… (#38311)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -19,6 +19,9 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
init_attn_backend,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
BatchExecutionDescriptor,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
@@ -239,6 +242,66 @@ class EagleSpeculator:
|
||||
idx_mapping, query_start_loc, pos, num_tokens_padded
|
||||
)
|
||||
|
||||
def _dispatch_and_sync_dp(
|
||||
self,
|
||||
cudagraph_manager: EagleCudaGraphManager,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
uniform_token_count: int | None,
|
||||
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
|
||||
batch_desc = cudagraph_manager.dispatch(
|
||||
num_reqs, num_tokens, uniform_token_count
|
||||
)
|
||||
num_tokens_across_dp = None
|
||||
if self.dp_size > 1:
|
||||
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
|
||||
cudagraph_manager,
|
||||
batch_desc,
|
||||
num_tokens,
|
||||
num_reqs,
|
||||
uniform_token_count,
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
return batch_desc, num_tokens_across_dp
|
||||
|
||||
def _build_draft_attn_metadata(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_reqs_padded: int,
|
||||
num_tokens_padded: int,
|
||||
max_query_len: int,
|
||||
) -> dict[str, Any] | None:
|
||||
if not self.draft_attn_layer_names:
|
||||
return None
|
||||
|
||||
query_start_loc_cpu = (
|
||||
torch.arange(num_reqs_padded + 1, dtype=torch.int32, device="cpu").clamp_(
|
||||
max=num_reqs
|
||||
)
|
||||
* max_query_len
|
||||
)
|
||||
block_tables = [
|
||||
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
|
||||
]
|
||||
slot_mappings = self.block_tables.slot_mappings[:, :num_tokens_padded]
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs_padded,
|
||||
num_tokens=num_tokens_padded,
|
||||
query_start_loc_gpu=self.input_buffers.query_start_loc[
|
||||
: num_reqs_padded + 1
|
||||
],
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=max_query_len,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
return
|
||||
@@ -319,7 +382,6 @@ class EagleSpeculator:
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
num_reqs_padded = input_batch.num_reqs_after_padding
|
||||
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||
# and ignore the other sampling parameters such as top_k and top_p,
|
||||
# for simplicity and performance.
|
||||
@@ -366,69 +428,49 @@ class EagleSpeculator:
|
||||
self.max_num_reqs,
|
||||
)
|
||||
|
||||
# Get batch descriptor and sync across DP ranks.
|
||||
# Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
|
||||
# Each request produces exactly 1 token per draft decode step,
|
||||
# enabling FULL cudagraph.
|
||||
decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp(
|
||||
self.cudagraph_manager,
|
||||
num_reqs,
|
||||
num_reqs,
|
||||
uniform_token_count=1,
|
||||
)
|
||||
|
||||
batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
|
||||
num_tokens_across_dp = None
|
||||
|
||||
if self.dp_size > 1:
|
||||
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
|
||||
self.cudagraph_manager,
|
||||
batch_desc,
|
||||
num_reqs,
|
||||
num_reqs,
|
||||
1, # uniform_token_count
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
|
||||
)
|
||||
|
||||
if batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
|
||||
|
||||
# Run eager or piecewise CUDA graph.
|
||||
attn_metadata_updated = None
|
||||
slot_mappings_updated = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [
|
||||
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
|
||||
]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata_updated = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs_padded,
|
||||
num_tokens=num_reqs_padded,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
# Build attention metadata and slot mappings for the draft
|
||||
# decode steps. It is necessary to rebuild the attention
|
||||
# metadata even when replaying the FULL cudagraph so that
|
||||
# any attention metadata builder state is updated.
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping,
|
||||
self.input_buffers.query_start_loc[: num_reqs + 1],
|
||||
pos,
|
||||
decode_batch_desc.num_tokens,
|
||||
)
|
||||
slot_mappings_updated = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
attn_metadata_updated = self._build_draft_attn_metadata(
|
||||
num_reqs=num_reqs,
|
||||
num_reqs_padded=decode_batch_desc.num_reqs or num_reqs,
|
||||
num_tokens_padded=decode_batch_desc.num_tokens,
|
||||
max_query_len=1,
|
||||
)
|
||||
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
batch_desc.num_tokens,
|
||||
attn_metadata_updated,
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=batch_desc.cg_mode,
|
||||
)
|
||||
if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
self.cudagraph_manager.run_fullgraph(decode_batch_desc)
|
||||
else:
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
decode_batch_desc.num_tokens,
|
||||
attn_metadata_updated,
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=decode_batch_desc.cg_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user