[Model Runner V2] Rebuild attention metadata before eagle decode full… (#38311)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-27 13:46:42 -07:00
committed by GitHub
parent 44a6528028
commit 384e4d5f48

View File

@@ -19,6 +19,9 @@ from vllm.v1.worker.gpu.attn_utils import (
init_attn_backend,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
@@ -239,6 +242,66 @@ class EagleSpeculator:
idx_mapping, query_start_loc, pos, num_tokens_padded
)
def _dispatch_and_sync_dp(
    self,
    cudagraph_manager: EagleCudaGraphManager,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
    """Select a batch execution descriptor and reconcile it across DP ranks.

    Args:
        cudagraph_manager: Manager whose ``dispatch`` picks the descriptor
            for this batch.
        num_reqs: Number of requests in the batch.
        num_tokens: Number of tokens in the batch.
        uniform_token_count: Forwarded unchanged to ``dispatch`` (and to the
            DP sync helper); may be ``None``.

    Returns:
        A ``(batch_desc, num_tokens_across_dp)`` pair. When
        ``self.dp_size`` is not greater than 1 the descriptor comes straight
        from ``dispatch`` and the second element is ``None``; otherwise both
        elements are the values produced by
        ``sync_cudagraph_and_dp_padding``.
    """
    local_desc = cudagraph_manager.dispatch(
        num_reqs, num_tokens, uniform_token_count
    )

    # Without data parallelism there is no other rank to agree with.
    if self.dp_size <= 1:
        return local_desc, None

    synced_desc, tokens_across_dp = sync_cudagraph_and_dp_padding(
        cudagraph_manager,
        local_desc,
        num_tokens,
        num_reqs,
        uniform_token_count,
        self.dp_size,
        self.dp_rank,
    )
    return synced_desc, tokens_across_dp
def _build_draft_attn_metadata(
    self,
    num_reqs: int,
    num_reqs_padded: int,
    num_tokens_padded: int,
    max_query_len: int,
) -> dict[str, Any] | None:
    """Build attention metadata for the draft model's attention layers.

    Args:
        num_reqs: Number of real (unpadded) requests.
        num_reqs_padded: Number of requests after padding; used to slice the
            per-request buffers and block tables.
        num_tokens_padded: Number of tokens after padding; used to slice the
            slot mappings.
        max_query_len: Query length assigned to every real request when the
            CPU-side query start offsets are constructed.

    Returns:
        The result of ``build_attn_metadata``, or ``None`` when
        ``self.draft_attn_layer_names`` is empty.
    """
    if not self.draft_attn_layer_names:
        return None

    # CPU-side query start offsets: request i starts at i * max_query_len.
    # Indices past num_reqs are clamped to num_reqs, so every padded request
    # gets the same start offset as the end of the last real request.
    req_boundaries = torch.arange(
        num_reqs_padded + 1, dtype=torch.int32, device="cpu"
    )
    req_boundaries.clamp_(max=num_reqs)
    cpu_query_start_loc = req_boundaries * max_query_len

    padded_block_tables = [
        table[:num_reqs_padded]
        for table in self.block_tables.input_block_tables
    ]
    padded_slot_mappings = self.block_tables.slot_mappings[
        :, :num_tokens_padded
    ]

    return build_attn_metadata(
        attn_groups=self.attn_groups,
        num_reqs=num_reqs_padded,
        num_tokens=num_tokens_padded,
        query_start_loc_gpu=self.input_buffers.query_start_loc[
            : num_reqs_padded + 1
        ],
        query_start_loc_cpu=cpu_query_start_loc,
        max_query_len=max_query_len,
        seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
        max_seq_len=self.max_model_len,
        block_tables=padded_block_tables,
        slot_mappings=padded_slot_mappings,
        kv_cache_config=self.kv_cache_config,
    )
def capture_model(self) -> None:
if self.num_speculative_steps == 1:
return
@@ -319,7 +382,6 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
@@ -366,69 +428,49 @@ class EagleSpeculator:
self.max_num_reqs,
)
# Get batch descriptor and sync across DP ranks.
# Eagle uses FULL-only mode, so dispatch with uniform_token_count=1 for
# decode: each request produces exactly 1 token per draft decode step,
# which enables the FULL cudagraph.
decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp(
self.cudagraph_manager,
num_reqs,
num_reqs,
uniform_token_count=1,
)
batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs,
num_reqs,
1, # uniform_token_count
self.dp_size,
self.dp_rank,
)
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
)
if batch_desc.cg_mode == CUDAGraphMode.FULL:
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
# Run eager or piecewise CUDA graph.
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
)
block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs_padded,
num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
# Build attention metadata and slot mappings for the draft
# decode steps. It is necessary to rebuild the attention
# metadata even when replaying the FULL cudagraph so that
# any attention metadata builder state is updated.
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
self.input_buffers.query_start_loc[: num_reqs + 1],
pos,
decode_batch_desc.num_tokens,
)
slot_mappings_updated = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata_updated = self._build_draft_attn_metadata(
num_reqs=num_reqs,
num_reqs_padded=decode_batch_desc.num_reqs or num_reqs,
num_tokens_padded=decode_batch_desc.num_tokens,
max_query_len=1,
)
self.generate_draft(
num_reqs,
batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
)
if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
self.cudagraph_manager.run_fullgraph(decode_batch_desc)
else:
self.generate_draft(
num_reqs,
decode_batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=decode_batch_desc.cg_mode,
)
return self.draft_tokens[:num_reqs]