[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Author:  Lucas Wilkinson, committed by GitHub, 2026-03-03 10:21:57 -05:00
Commit:  28ef9ba399 (parent fb7fdc49c4)
Changes: 7 changed files with 260 additions and 197 deletions


@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model():
         proposer.max_num_tokens, dtype=torch.bool, device=device
     )
 
-    # Mock the attn_metadata_builder to avoid needing the full model setup
+    # Mock draft_attn_groups to avoid needing the full model setup
     mock_kv_cache_spec = mock.MagicMock()
     mock_kv_cache_spec.block_size = block_size
-    mock_builder = mock.MagicMock()
-    mock_builder.kv_cache_spec = mock_kv_cache_spec
-    proposer.attn_metadata_builder = mock_builder
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.kv_cache_spec = mock_kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
     batch_spec = BatchSpec(
@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting():
         proposer.max_num_tokens, dtype=torch.bool, device=device
     )
 
-    # Mock the attn_metadata_builder
+    # Mock draft_attn_groups
     mock_kv_cache_spec = mock.MagicMock()
     mock_kv_cache_spec.block_size = block_size
-    mock_builder = mock.MagicMock()
-    mock_builder.kv_cache_spec = mock_kv_cache_spec
-    proposer.attn_metadata_builder = mock_builder
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.kv_cache_spec = mock_kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
     batch_spec = BatchSpec(
@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     proposer.model = model_mock
 
     # Assign draft attn_layer_names since load_model is not invoked
-    proposer.attn_layer_names = ["layer.0"]
+    proposer._draft_attn_layer_names = {"layer.0"}
 
     # Create input tensors
     batch_spec = BatchSpec(
@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
-        layer_names=proposer.attn_layer_names,
+        layer_names=proposer._draft_attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
 
-    # Mock runner for attention metadata building
+    # Mock runner and draft_attn_groups for attention metadata building
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][
-        0
-    ].get_metadata_builder.return_value = attn_metadata_builder
-    proposer._get_attention_metadata_builder = mock.MagicMock(
-        return_value=attn_metadata_builder
-    )
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     result = proposer.propose(
         target_token_ids=target_token_ids,
@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree):
     proposer.model = model_mock
 
     # Assign draft attn_layer_names since load_model is not invoked
-    proposer.attn_layer_names = ["layer.0"]
+    proposer._draft_attn_layer_names = {"layer.0"}
 
     # Get the tree attention metadata builder.
     attn_metadata_builder_cls, _ = try_get_attention_backend(
@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree):
     )
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
-        layer_names=proposer.attn_layer_names,
+        layer_names=proposer._draft_attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
 
-    # Mock runner for attention metadata building.
+    # Mock runner and draft_attn_groups for attention metadata building.
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
-    proposer.runner.attn_groups[0][
-        0
-    ].get_metadata_builder.return_value = attn_metadata_builder
-    proposer._get_attention_metadata_builder = mock.MagicMock(
-        return_value=attn_metadata_builder
-    )
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)


@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
     model_mock.compute_logits.side_effect = logits_returns
 
     proposer.model = model_mock
-    proposer.attn_layer_names = ["layer.0"]
+    proposer._draft_attn_layer_names = {"layer.0"}
 
     # Prepare inputs
     batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
-        layer_names=proposer.attn_layer_names,
+        layer_names=list(proposer._draft_attn_layer_names),
         vllm_config=proposer.vllm_config,
         device=device,
     )
 
     proposer.runner = mock.MagicMock()
-    proposer.attn_metadata_builder = attn_metadata_builder
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Run propose
     result = proposer.propose(


@@ -79,6 +79,12 @@ def sparse_attn_indexer(
     has_prefill = attn_metadata.num_prefills > 0
     num_decode_tokens = attn_metadata.num_decode_tokens
+    # During speculative decoding, k may be padded to the CUDA graph batch
+    # size while slot_mapping only covers actual tokens. Truncate k to avoid
+    # out-of-bounds reads in the kernel.
+    num_tokens = slot_mapping.shape[0]
+    k = k[:num_tokens]
     ops.indexer_k_quant_and_cache(
         k,
         kv_cache,

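The truncation added above matters because CUDA graph capture pads the token dimension. Below is a minimal standalone sketch of the shape mismatch; all shapes are illustrative and not the real kernel interface:

```python
import torch

# k is padded to the CUDA graph batch size; slot_mapping only covers the
# tokens that actually exist, so its length is the source of truth.
cudagraph_batch_size = 16  # illustrative padded size
num_actual_tokens = 10     # illustrative number of real tokens
head_dim = 128             # illustrative

k = torch.randn(cudagraph_batch_size, head_dim)
slot_mapping = torch.arange(num_actual_tokens)

# Truncate k to slot_mapping's length before writing to the cache;
# otherwise the kernel would read slots past the end of slot_mapping.
k = k[: slot_mapping.shape[0]]
assert k.shape[0] == slot_mapping.shape[0]
```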

@@ -12,6 +12,7 @@ from vllm.utils.deep_gemm import (
     get_paged_mqa_logits_metadata,
     is_deep_gemm_supported,
 )
+from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -24,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
     split_decodes_and_prefills,
     split_prefill_chunks,
 )
+from vllm.v1.worker.cp_utils import get_total_cp_world_size
 
 logger = init_logger(__name__)
@@ -214,20 +216,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             if self.vllm_config.speculative_config
             else 0
         )
-        if self.num_speculative_tokens > 1:
-            raise ValueError(
-                "Sparse MLA only supports "
-                "num_speculative_tokens <= 1 because the DeepGEMM "
-                "fp8_paged_mqa_logits kernel does not support next_n > 2. "
-                f"Got num_speculative_tokens={self.num_speculative_tokens}."
-            )
         self.reorder_batch_threshold += self.num_speculative_tokens
 
         sm_count = num_compute_units(self.device.index)
         self.num_sms = sm_count
 
         self.decode_lens_buffer = torch.empty(
-            (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
+            (scheduler_config.max_num_batched_tokens,),
+            dtype=torch.int32,
+            device=self.device,
         )
+        # Pre-allocated buffers for flattening (spec decode).
+        self.arange_buffer = torch.arange(
+            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.expanded_seq_lens_buffer = torch.zeros(
+            (scheduler_config.max_num_batched_tokens,),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        max_num_blocks_per_req = cdiv(
+            self.vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size * get_total_cp_world_size(),
+        )
+        self.expanded_block_table_buffer = torch.zeros(
+            (
+                scheduler_config.max_num_batched_tokens,
+                max_num_blocks_per_req,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
 
         # See: DeepGMM/csrc/apis/attention.hpp
@@ -326,42 +347,97 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
         )
-        # Use CPU to avoid GPU sync; breaking async scheduling
-        requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
-
-        # Decide which top-k kernel to use based on batch size and sequence length
-        batch_size = num_decodes
-        _is_large_context = common_attn_metadata.max_seq_len > 8192
-        # Decision logic based on micro-benchmark results:
-        # - large_context_topk wins for batch <= 128 and seq_len > 8K
-        # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
-        use_large_context_topk = batch_size <= 128 and _is_large_context
-
-        next_n = 1 + self.num_speculative_tokens
-        if next_n > 1:
-            offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
-        else:
-            offsets = None
 
         seq_lens = common_attn_metadata.seq_lens[:num_decodes]
-        # DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and is_deep_gemm_supported():
-            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
-                seq_lens, self.kv_cache_spec.block_size, self.num_sms
-            )
         block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
+        # Padded CUDA graph requests have block_table entries of -1.
+        # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
+        # This is safe because padded requests have seq_lens=0, so the
+        # kernel produces no meaningful output for those rows.
+        block_table.clamp_(min=0)
+
+        max_decode_len = int(decode_lens_cpu.max().item())
+        if max_decode_len > 1:
+            # Flatten multi-token decode requests into single-token
+            # batch entries, expanding seq_lens and block tables so
+            # the kernel always sees next_n=1.
+            # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
+            # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
+            # The context lengths are therefore
+            # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
+
+            # 3 + 1 + 4 + 0 = 8
+            actual_expanded = int(decode_lens_cpu.sum().item())
+
+            # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
+            expanded_base = torch.repeat_interleave(
+                seq_lens - decode_lens, decode_lens
+            )
+            # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
+            expanded_starts = torch.repeat_interleave(
+                common_attn_metadata.query_start_loc[:num_decodes], decode_lens
+            )
+            # [0, 1, 2, 0, 0, 1, 2, 3]
+            positions_within = (
+                self.arange_buffer[:actual_expanded] - expanded_starts
+            )
+            # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
+            self.expanded_seq_lens_buffer[:actual_expanded] = (
+                expanded_base + positions_within + 1
+            )
+            self.expanded_seq_lens_buffer[actual_expanded:] = 0
+            seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
+
+            # Give each of the flattened entries the same block table row as the
+            # original request.
+            self.expanded_block_table_buffer[:actual_expanded] = (
+                torch.repeat_interleave(block_table, decode_lens, dim=0)
+            )
+            if actual_expanded < num_decode_tokens:
+                self.expanded_block_table_buffer[
+                    actual_expanded:num_decode_tokens, 0
+                ] = 0
+            block_table = self.expanded_block_table_buffer[:num_decode_tokens]
+
+            # All reqs now have decode_len=1
+            self.decode_lens_buffer[:num_decode_tokens] = 1
+            decode_lens = self.decode_lens_buffer[:num_decode_tokens]
+            offsets = None
+            batch_size = num_decode_tokens
+        else:
+            next_n = 1 + self.num_speculative_tokens
+            if next_n > 1:
+                offsets = torch.arange(
+                    next_n, device=self.device, dtype=torch.int32
+                )
+            else:
+                offsets = None
+            batch_size = num_decodes
+
+        # DeepGEMM is required for the paged MQA logits on CUDA devices
+        if current_platform.is_cuda() and is_deep_gemm_supported():
+            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
+                seq_lens,
+                self.kv_cache_spec.block_size,
+                self.num_sms,
+            )
+
+        # Decide which top-k kernel to use based on batch size and sequence length
+        # Decision logic based on micro-benchmark results:
+        # - large_context_topk wins for batch <= 128 and seq_len > 8K
+        # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
+        _is_large_context = common_attn_metadata.max_seq_len > 8192
+        use_large_context_topk = batch_size <= 128 and _is_large_context
+
         decode_metadata = DeepSeekV32IndexerDecodeMetadata(
             block_table=block_table,
-            seq_lens=common_attn_metadata.seq_lens[:num_decodes],
+            seq_lens=seq_lens,
             decode_lens=decode_lens,
-            requires_padding=requires_padding,
+            requires_padding=False,
             schedule_metadata=self.scheduler_metadata_buffer,
             use_large_context_topk=use_large_context_topk,
             offsets=offsets,

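The flattening above can be reproduced in isolation. Here is a runnable sketch using the exact numbers from the inline comments; plain tensors stand in for the pre-allocated buffers:

```python
import torch

# 4 decode requests; the last one is CUDA graph padding (seq_len 0).
seq_lens = torch.tensor([10, 7, 12, 0], dtype=torch.int32)
decode_lens = torch.tensor([3, 1, 4, 0], dtype=torch.int32)
query_start_loc = torch.tensor([0, 3, 4, 8], dtype=torch.int32)

actual_expanded = int(decode_lens.sum())  # 3 + 1 + 4 + 0 = 8

# Context length per request, repeated once per decode token:
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(seq_lens - decode_lens, decode_lens)

# Start offset per request, repeated the same way:
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(query_start_loc, decode_lens)

# Position of each token within its request: [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
    torch.arange(actual_expanded, dtype=torch.int32) - expanded_starts
)

# Effective seq_len of each flattened single-token entry.
expanded_seq_lens = expanded_base + positions_within + 1
print(expanded_seq_lens.tolist())  # [8, 9, 10, 7, 9, 10, 11, 12]
```

Each flattened entry then reuses its request's block-table row, so the DeepGEMM kernel only ever sees next_n=1 regardless of num_speculative_tokens.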

@@ -20,17 +20,13 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
-from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
 from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.attention.backend import (
-    AttentionMetadataBuilder,
-    CommonAttentionMetadata,
-)
+from vllm.v1.attention.backend import CommonAttentionMetadata
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.backends.tree_attn import (
     TreeAttentionMetadata,
@@ -38,7 +34,7 @@ from vllm.v1.attention.backends.tree_attn import (
 )
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -53,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.utils import AttentionGroup
 
 logger = init_logger(__name__)
@@ -113,10 +110,8 @@ class SpecDecodeBaseProposer:
             vllm_config.model_config
         )
-        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
-        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
-        self.attn_layer_names: list[str] = []
-        self.indexer_layer_names: list[str] = []
+        self.draft_attn_groups: list[AttentionGroup] = []
+        self.kv_cache_gid: int = -1
         self.eagle3_use_aux_hidden_state: bool = (
             self._get_eagle3_use_aux_hidden_state_from_config()
         )
@@ -353,7 +348,7 @@ class SpecDecodeBaseProposer:
         self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
         view = self._slot_mapping_buffer[:num_tokens]
-        return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
+        return {name: view for name in self._draft_attn_layer_names}
 
     def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
         """Initialize cudagraph dispatcher keys for eagle.
@@ -420,33 +415,13 @@ class SpecDecodeBaseProposer:
         assert self.runner is not None
 
-        if self.attn_metadata_builder is None:
-            attn_metadata_builder = self._get_attention_metadata_builder()
-        else:
-            attn_metadata_builder = self.attn_metadata_builder
-
-        attn_metadata = attn_metadata_builder.build_for_drafting(
-            common_attn_metadata=common_attn_metadata, draft_index=0
-        )
-
-        # FIXME: support hybrid kv for draft model (remove separate indexer)
-        if self.draft_indexer_metadata_builder:
-            draft_indexer_metadata = (
-                self.draft_indexer_metadata_builder.build_for_drafting(
-                    common_attn_metadata=common_attn_metadata,
-                    draft_index=0,
-                )
+        per_layer_attn_metadata: dict[str, object] = {}
+        for attn_group in self.draft_attn_groups:
+            attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
+                common_attn_metadata=common_attn_metadata, draft_index=0
             )
-        else:
-            draft_indexer_metadata = None
-
-        # At this moment, we assume all eagle layers belong to the same KV
-        # cache group, thus using the same attention metadata.
-        per_layer_attn_metadata = {}
-        for layer_name in self.attn_layer_names:
-            per_layer_attn_metadata[layer_name] = attn_metadata
-        for layer_name in self.indexer_layer_names:
-            assert draft_indexer_metadata is not None
-            per_layer_attn_metadata[layer_name] = draft_indexer_metadata
+            for layer_name in attn_group.layer_names:
+                per_layer_attn_metadata[layer_name] = attn_metadata
 
         cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
             self._determine_batch_execution_and_padding(num_tokens)
@@ -503,12 +478,7 @@ class SpecDecodeBaseProposer:
             positions = self.mrope_positions[:, token_indices_to_sample]
         else:
             positions = self.positions[token_indices_to_sample]
-        if self.method in (
-            "deepseek_mtp",
-            "ernie_mtp",
-            "longcat_flash_mtp",
-            "pangu_ultra_moe_mtp",
-        ):
+        if self.method == "mtp":
             hidden_states = self.hidden_states[token_indices_to_sample]
         else:
             hidden_states = hidden_states[token_indices_to_sample]
@@ -613,7 +583,8 @@ class SpecDecodeBaseProposer:
             common_attn_metadata._num_computed_tokens_cpu += 1
 
             # Compute the slot mapping.
-            block_size = attn_metadata_builder.kv_cache_spec.block_size
+            # Use the first draft attention group's kv_cache_spec for block_size
+            block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
             if self.uses_mrope:
                 # all dimensions of positions are the same
                 block_numbers = clamped_positions[0] // block_size
@@ -639,11 +610,13 @@ class SpecDecodeBaseProposer:
             )
 
             # Rebuild attention metadata
-            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
-                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
-            )
-            for layer_name in self.attn_layer_names:
-                per_layer_attn_metadata[layer_name] = attn_metadata
+            for attn_group in self.draft_attn_groups:
+                attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
+                    common_attn_metadata=common_attn_metadata,
+                    draft_index=token_index + 1,
+                )
+                for layer_name in attn_group.layer_names:
+                    per_layer_attn_metadata[layer_name] = attn_metadata
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
@@ -805,18 +778,17 @@ class SpecDecodeBaseProposer:
         # 2.
 
         # Recompute the slot mapping based on the new positions and
         # rejection mask.
-        builder = (
-            self._get_attention_metadata_builder()
-            if self.attn_metadata_builder is None
-            else self.attn_metadata_builder
-        )
+        # Use the first draft attention group's kv_cache_spec for block_size
+        # (all draft layers share the same kv-cache group)
+        assert len(self.draft_attn_groups) > 0
+        block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
         new_slot_mapping = compute_new_slot_mapping(
             cad=cad,
             new_positions=self.positions[:total_num_output_tokens],
             is_rejected_token_mask=self.is_rejected_token_mask[
                 :total_num_output_tokens
             ],
-            block_size=builder.kv_cache_spec.block_size,
+            block_size=block_size,
             num_new_tokens=self.net_num_new_slots_per_request,
             max_model_len=self.max_model_len,
         )
@@ -1000,9 +972,7 @@ class SpecDecodeBaseProposer:
         | list[dict[str, torch.Tensor]]
         | None = None,
     ) -> list[torch.Tensor]:
-        tree_attn_metadata_builder = self.runner.attn_groups[0][
-            0
-        ].get_metadata_builder()
+        tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
         assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
 
         total_num_drafts = self.cu_drafts_per_level[0]
@@ -1078,10 +1048,11 @@ class SpecDecodeBaseProposer:
                 common_attn_metadata=common_attn_metadata, draft_index=level + 1
             )
 
-            # Apply new attention metadata to all layers.
+            # Apply new attention metadata to all draft layers.
             per_layer_attn_metadata = {}
-            for layer_name in self.attn_layer_names:
-                per_layer_attn_metadata[layer_name] = attn_metadata
+            for attn_group in self.draft_attn_groups:
+                for layer_name in attn_group.layer_names:
+                    per_layer_attn_metadata[layer_name] = attn_metadata
 
             # Consider max model length.
             attn_metadata.max_seq_len = min(
@@ -1288,43 +1259,17 @@ class SpecDecodeBaseProposer:
                 AttentionLayerBase,  # type: ignore[type-abstract]
             ).keys()
         )
-        # FIXME: support hybrid kv for draft model
-        target_indexer_layer_names = set(
-            get_layers_from_vllm_config(
-                self.vllm_config, DeepseekV32IndexerCache
-            ).keys()
-        )
 
         self.model = self._get_model()
 
-        draft_attn_layer_names = (
-            get_layers_from_vllm_config(
-                self.vllm_config,
-                AttentionLayerBase,  # type: ignore[type-abstract]
-            ).keys()
-            - target_attn_layer_names
+        # Find draft layers (attention layers added by draft model)
+        all_attn_layers = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
         )
-        indexer_layers = get_layers_from_vllm_config(
-            self.vllm_config, DeepseekV32IndexerCache
+        self._draft_attn_layer_names = (
+            set(all_attn_layers.keys()) - target_attn_layer_names
         )
-        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
-        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
-        self.indexer_layer_names = list(draft_indexer_layer_names)
-
-        if self.indexer_layer_names:
-            first_layer = self.indexer_layer_names[0]
-            self.draft_indexer_metadata_builder = (
-                indexer_layers[first_layer]
-                .get_attn_backend()
-                .get_builder_cls()(
-                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
-                    self.indexer_layer_names,
-                    self.vllm_config,
-                    self.device,
-                )
-            )
-        else:
-            self.draft_indexer_metadata_builder = None
 
         if self.supports_mm_inputs:
             # Even if the target model is multimodal, we can also use
@@ -1562,9 +1507,9 @@ class SpecDecodeBaseProposer:
         # Make sure to use EAGLE's own buffer during cudagraph capture.
         if (
-            self.attn_layer_names
+            self._draft_attn_layer_names
             and slot_mappings is not None
-            and self.attn_layer_names[0] in slot_mappings
+            and next(iter(self._draft_attn_layer_names)) in slot_mappings
         ):
             slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
         else:
@@ -1594,31 +1539,6 @@ class SpecDecodeBaseProposer:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
Returns:
The metadata builders for EAGLE layers.
Raises:
AssertionError: If no metadata builders are found for EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break
assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers."
)
return builder
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
"""
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
@@ -1651,13 +1571,71 @@ class SpecDecodeBaseProposer:
                 set(
                     [
                         kv_cache_groups[layer_name]
-                        for layer_name in self.attn_layer_names
+                        for layer_name in self._draft_attn_layer_names
                     ]
                 )
             )
             == 1
         ), "All drafting layers should belong to the same kv cache group"
 
+    def initialize_attn_backend(
+        self,
+        kv_cache_config: KVCacheConfig,
+        kernel_block_sizes: list[int] | None = None,
+    ) -> None:
+        """
+        Initialize AttentionGroups for draft layers using kv_cache_config.
+        Called from the model runner's initialize_metadata_builders.
+        """
+        all_attn_layers = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        )
+
+        # Find which kv_cache_group the draft layers belong to
+        self.validate_same_kv_cache_group(kv_cache_config)
+        kv_cache_spec = None
+        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
+            if self._draft_attn_layer_names & set(group.layer_names):
+                self.kv_cache_gid = gid
+                kv_cache_spec = group.kv_cache_spec
+                break
+
+        attention_groups: dict[tuple[str, str], AttentionGroup] = {}
+        if kv_cache_spec is not None:
+            for layer_name in self._draft_attn_layer_names:
+                attn_backend = all_attn_layers[layer_name].get_attn_backend()
+                backend_key = attn_backend.full_cls_name()
+                if backend_key not in attention_groups:
+                    layer_kv_cache_spec = kv_cache_spec
+                    if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                        layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                            layer_name
+                        ]
+                    kernel_block_size = (
+                        kernel_block_sizes[self.kv_cache_gid]
+                        if kernel_block_sizes is not None
+                        and self.kv_cache_gid < len(kernel_block_sizes)
+                        else None
+                    )
+                    attn_group = AttentionGroup(
+                        backend=attn_backend,
+                        layer_names=[layer_name],
+                        kv_cache_spec=layer_kv_cache_spec,
+                        kv_cache_group_id=self.kv_cache_gid,
+                    )
+                    attn_group.create_metadata_builders(
+                        self.vllm_config,
+                        self.device,
+                        kernel_block_size=kernel_block_size,
+                    )
+                    attention_groups[backend_key] = attn_group
+                else:
+                    attention_groups[backend_key].layer_names.append(layer_name)
+
+        self.draft_attn_groups = list(attention_groups.values())
+
     def _determine_batch_execution_and_padding(
         self,
         num_tokens: int,

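For orientation: the net effect of initialize_attn_backend is to bucket draft layers by attention backend, one AttentionGroup per backend, so a DeepSeek-style indexer layer no longer needs the dedicated draft_indexer_metadata_builder path. A hedged sketch of just the bucketing step; layer and backend names here are made up:

```python
from collections import defaultdict

# Illustrative layer -> backend mapping; the real keys come from
# attn_backend.full_cls_name() on each draft layer.
layer_backends = {
    "draft.layers.0.attn": "FlashAttentionBackend",
    "draft.layers.1.attn": "FlashAttentionBackend",
    "draft.layers.0.indexer": "DeepseekV32IndexerBackend",
}

groups: dict[str, list[str]] = defaultdict(list)
for layer_name, backend_key in layer_backends.items():
    groups[backend_key].append(layer_name)

# Two groups result: the dense attention layers share one metadata
# builder, the indexer gets its own, and propose() iterates both
# uniformly via draft_attn_groups.
print(dict(groups))
```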

@@ -1936,7 +1936,7 @@ class GPUModelRunner(
         if self.speculative_config and spec_decode_common_attn_metadata is None:
             if isinstance(self.drafter, EagleProposer):
-                if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
+                if self.drafter.kv_cache_gid == kv_cache_gid:
                     spec_decode_common_attn_metadata = cm
             else:
                 spec_decode_common_attn_metadata = cm
@@ -5559,6 +5559,14 @@ class GPUModelRunner(
         # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()
 
+        # Initialize drafter attention backend
+        if self.speculative_config and (
+            self.speculative_config.use_eagle()
+            or self.speculative_config.uses_draft_model()
+        ):
+            assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
+            self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
+
     def _check_and_update_cudagraph_mode(
         self,
         attention_backends: list[set[type[AttentionBackend]]],
@@ -6079,15 +6087,11 @@ class GPUModelRunner(
             kv_cache_config, kernel_block_sizes
         )
 
-        if self.speculative_config and (
-            self.speculative_config.use_eagle()
-            or self.speculative_config.uses_draft_model()
-            or self.speculative_config.uses_extract_hidden_states()
+        if (
+            self.speculative_config
+            and self.speculative_config.uses_extract_hidden_states()
         ):
-            assert isinstance(
-                self.drafter,
-                EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
-            )
+            assert isinstance(self.drafter, ExtractHiddenStatesProposer)
             # validate all draft model layers belong to the same kv cache
             # group
             self.drafter.validate_same_kv_cache_group(kv_cache_config)

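The runner-side wiring, condensed: initialize_metadata_builders now hands the drafter its kv-cache config so the drafter can build its own groups, and spec-decode metadata reuse is keyed on the recorded kv_cache_gid instead of probing layer names. A simplified sketch with the method names from the diff; the class and the _pick_spec_decode_metadata helper are illustrative stand-ins, not the real GPUModelRunner:

```python
class _RunnerSketch:
    """Illustrative only: shows the call order, not the real runner."""

    def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
        self.calculate_reorder_batch_threshold()
        # New in this commit: EAGLE / draft-model drafters build their own
        # AttentionGroups instead of borrowing a builder from the target.
        if self.speculative_config and (
            self.speculative_config.use_eagle()
            or self.speculative_config.uses_draft_model()
        ):
            self.drafter.initialize_attn_backend(
                kv_cache_config, kernel_block_sizes
            )

    def _pick_spec_decode_metadata(self, kv_cache_gid, cm):
        # Hypothetical helper: match on the drafter's kv-cache group id.
        if self.drafter.kv_cache_gid == kv_cache_gid:
            return cm
        return None
```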

@@ -48,7 +48,7 @@ class AttentionGroup:
         self,
         vllm_config,
         device,
-        kernel_block_size: int | None,
+        kernel_block_size: int | None = None,
         num_metadata_builders: int = 1,
     ):
         kv_cache_spec_builder = (