[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model():
|
||||
proposer.max_num_tokens, dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# Mock the attn_metadata_builder to avoid needing the full model setup
|
||||
# Mock draft_attn_groups to avoid needing the full model setup
|
||||
mock_kv_cache_spec = mock.MagicMock()
|
||||
mock_kv_cache_spec.block_size = block_size
|
||||
mock_builder = mock.MagicMock()
|
||||
mock_builder.kv_cache_spec = mock_kv_cache_spec
|
||||
proposer.attn_metadata_builder = mock_builder
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.kv_cache_spec = mock_kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
# Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
|
||||
batch_spec = BatchSpec(
|
||||
@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting():
|
||||
proposer.max_num_tokens, dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# Mock the attn_metadata_builder
|
||||
# Mock draft_attn_groups
|
||||
mock_kv_cache_spec = mock.MagicMock()
|
||||
mock_kv_cache_spec.block_size = block_size
|
||||
mock_builder = mock.MagicMock()
|
||||
mock_builder.kv_cache_spec = mock_kv_cache_spec
|
||||
proposer.attn_metadata_builder = mock_builder
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.kv_cache_spec = mock_kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
# Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
|
||||
batch_spec = BatchSpec(
|
||||
@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
proposer.model = model_mock
|
||||
|
||||
# Assign draft attn_layer_names since load_model is not invoked
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
# Create input tensors
|
||||
batch_spec = BatchSpec(
|
||||
@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
layer_names=proposer._draft_attn_layer_names,
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Mock runner for attention metadata building
|
||||
# Mock runner and draft_attn_groups for attention metadata building
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][
|
||||
0
|
||||
].get_metadata_builder.return_value = attn_metadata_builder
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
return_value=attn_metadata_builder
|
||||
)
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
|
||||
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
|
||||
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
result = proposer.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree):
|
||||
proposer.model = model_mock
|
||||
|
||||
# Assign draft attn_layer_names since load_model is not invoked
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
# Get the tree attention metadata builder.
|
||||
attn_metadata_builder_cls, _ = try_get_attention_backend(
|
||||
@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree):
|
||||
)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
layer_names=proposer._draft_attn_layer_names,
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Mock runner for attention metadata building.
|
||||
# Mock runner and draft_attn_groups for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
|
||||
proposer.runner.attn_groups[0][
|
||||
0
|
||||
].get_metadata_builder.return_value = attn_metadata_builder
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
return_value=attn_metadata_builder
|
||||
)
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
|
||||
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
|
||||
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
# Setup inputs for the proposer.
|
||||
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
|
||||
|
||||
@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
model_mock.compute_logits.side_effect = logits_returns
|
||||
|
||||
proposer.model = model_mock
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
# Prepare inputs
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
|
||||
@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
layer_names=list(proposer._draft_attn_layer_names),
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.attn_metadata_builder = attn_metadata_builder
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
|
||||
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
|
||||
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
# Run propose
|
||||
result = proposer.propose(
|
||||
|
||||
@@ -79,6 +79,12 @@ def sparse_attn_indexer(
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
# During speculative decoding, k may be padded to the CUDA graph batch
|
||||
# size while slot_mapping only covers actual tokens. Truncate k to avoid
|
||||
# out-of-bounds reads in the kernel.
|
||||
num_tokens = slot_mapping.shape[0]
|
||||
k = k[:num_tokens]
|
||||
|
||||
ops.indexer_k_quant_and_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.utils.deep_gemm import (
|
||||
get_paged_mqa_logits_metadata,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
@@ -24,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -214,20 +216,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
if self.num_speculative_tokens > 1:
|
||||
raise ValueError(
|
||||
"Sparse MLA only supports "
|
||||
"num_speculative_tokens <= 1 because the DeepGEMM "
|
||||
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
|
||||
f"Got num_speculative_tokens={self.num_speculative_tokens}."
|
||||
)
|
||||
self.reorder_batch_threshold += self.num_speculative_tokens
|
||||
|
||||
sm_count = num_compute_units(self.device.index)
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
|
||||
(scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Pre-allocated buffers for flattening (spec decode).
|
||||
self.arange_buffer = torch.arange(
|
||||
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.expanded_seq_lens_buffer = torch.zeros(
|
||||
(scheduler_config.max_num_batched_tokens,),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
max_num_blocks_per_req = cdiv(
|
||||
self.vllm_config.model_config.max_model_len,
|
||||
self.kv_cache_spec.block_size * get_total_cp_world_size(),
|
||||
)
|
||||
self.expanded_block_table_buffer = torch.zeros(
|
||||
(
|
||||
scheduler_config.max_num_batched_tokens,
|
||||
max_num_blocks_per_req,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# See: DeepGMM/csrc/apis/attention.hpp
|
||||
@@ -326,42 +347,97 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
# Decide which top-k kernel to use based on batch size and sequence length
|
||||
batch_size = num_decodes
|
||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
||||
|
||||
# Decision logic based on micro-benchmark results:
|
||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
||||
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
if next_n > 1:
|
||||
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
|
||||
else:
|
||||
offsets = None
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||
if current_platform.is_cuda() and is_deep_gemm_supported():
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
)
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||
|
||||
# Padded CUDA graph requests have block_table entries of -1.
|
||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
||||
# This is safe because padded requests have seq_lens=0, so the
|
||||
# kernel produces no meaningful output for those rows.
|
||||
block_table.clamp_(min=0)
|
||||
|
||||
max_decode_len = int(decode_lens_cpu.max().item())
|
||||
if max_decode_len > 1:
|
||||
# Flatten multi-token decode requests into single-token
|
||||
# batch entries, expanding seq_lens and block tables so
|
||||
# the kernel always sees next_n=1.
|
||||
|
||||
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
|
||||
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
|
||||
# The context lengths are therefore
|
||||
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
|
||||
|
||||
# 3 + 1 + 4 + 0 = 8
|
||||
actual_expanded = int(decode_lens_cpu.sum().item())
|
||||
|
||||
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
|
||||
expanded_base = torch.repeat_interleave(
|
||||
seq_lens - decode_lens, decode_lens
|
||||
)
|
||||
|
||||
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
|
||||
expanded_starts = torch.repeat_interleave(
|
||||
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
|
||||
)
|
||||
|
||||
# [0, 1, 2, 0, 0, 1, 2, 3]
|
||||
positions_within = (
|
||||
self.arange_buffer[:actual_expanded] - expanded_starts
|
||||
)
|
||||
|
||||
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
|
||||
self.expanded_seq_lens_buffer[:actual_expanded] = (
|
||||
expanded_base + positions_within + 1
|
||||
)
|
||||
self.expanded_seq_lens_buffer[actual_expanded:] = 0
|
||||
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
|
||||
|
||||
# Give each of the flattened entries the same block table row as the
|
||||
# original request.
|
||||
self.expanded_block_table_buffer[:actual_expanded] = (
|
||||
torch.repeat_interleave(block_table, decode_lens, dim=0)
|
||||
)
|
||||
if actual_expanded < num_decode_tokens:
|
||||
self.expanded_block_table_buffer[
|
||||
actual_expanded:num_decode_tokens, 0
|
||||
] = 0
|
||||
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
|
||||
|
||||
# All reqs now have decode_len=1
|
||||
self.decode_lens_buffer[:num_decode_tokens] = 1
|
||||
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
|
||||
offsets = None
|
||||
batch_size = num_decode_tokens
|
||||
else:
|
||||
next_n = 1 + self.num_speculative_tokens
|
||||
if next_n > 1:
|
||||
offsets = torch.arange(
|
||||
next_n, device=self.device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
offsets = None
|
||||
batch_size = num_decodes
|
||||
|
||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
||||
if current_platform.is_cuda() and is_deep_gemm_supported():
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens,
|
||||
self.kv_cache_spec.block_size,
|
||||
self.num_sms,
|
||||
)
|
||||
|
||||
# Decide which top-k kernel to use based on batch size and sequence length
|
||||
# Decision logic based on micro-benchmark results:
|
||||
# - large_context_topk wins for batch <= 128 and seq_len > 8K
|
||||
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
|
||||
_is_large_context = common_attn_metadata.max_seq_len > 8192
|
||||
use_large_context_topk = batch_size <= 128 and _is_large_context
|
||||
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=block_table,
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
seq_lens=seq_lens,
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
requires_padding=False,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
use_large_context_topk=use_large_context_topk,
|
||||
offsets=offsets,
|
||||
|
||||
@@ -20,17 +20,13 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.tree_attn import (
|
||||
TreeAttentionMetadata,
|
||||
@@ -38,7 +34,7 @@ from vllm.v1.attention.backends.tree_attn import (
|
||||
)
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
@@ -53,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -113,10 +110,8 @@ class SpecDecodeBaseProposer:
|
||||
vllm_config.model_config
|
||||
)
|
||||
|
||||
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
self.draft_attn_groups: list[AttentionGroup] = []
|
||||
self.kv_cache_gid: int = -1
|
||||
self.eagle3_use_aux_hidden_state: bool = (
|
||||
self._get_eagle3_use_aux_hidden_state_from_config()
|
||||
)
|
||||
@@ -353,7 +348,7 @@ class SpecDecodeBaseProposer:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
|
||||
return {name: view for name in self._draft_attn_layer_names}
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys for eagle.
|
||||
@@ -420,33 +415,13 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
if self.attn_metadata_builder is None:
|
||||
attn_metadata_builder = self._get_attention_metadata_builder()
|
||||
else:
|
||||
attn_metadata_builder = self.attn_metadata_builder
|
||||
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model (remove separate indexer)
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = (
|
||||
self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
)
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
@@ -503,12 +478,7 @@ class SpecDecodeBaseProposer:
|
||||
positions = self.mrope_positions[:, token_indices_to_sample]
|
||||
else:
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
if self.method in (
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
):
|
||||
if self.method == "mtp":
|
||||
hidden_states = self.hidden_states[token_indices_to_sample]
|
||||
else:
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
@@ -613,7 +583,8 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_size = attn_metadata_builder.kv_cache_spec.block_size
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
if self.uses_mrope:
|
||||
# all dimensions of positions are the same
|
||||
block_numbers = clamped_positions[0] // block_size
|
||||
@@ -639,11 +610,13 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1,
|
||||
)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
@@ -805,18 +778,17 @@ class SpecDecodeBaseProposer:
|
||||
# 2.
|
||||
# Recompute the slot mapping based on the new positions and
|
||||
# rejection mask.
|
||||
builder = (
|
||||
self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None
|
||||
else self.attn_metadata_builder
|
||||
)
|
||||
# Use the first draft attention group's kv_cache_spec for block_size
|
||||
# (all draft layers share the same kv-cache group)
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
|
||||
new_slot_mapping = compute_new_slot_mapping(
|
||||
cad=cad,
|
||||
new_positions=self.positions[:total_num_output_tokens],
|
||||
is_rejected_token_mask=self.is_rejected_token_mask[
|
||||
:total_num_output_tokens
|
||||
],
|
||||
block_size=builder.kv_cache_spec.block_size,
|
||||
block_size=block_size,
|
||||
num_new_tokens=self.net_num_new_slots_per_request,
|
||||
max_model_len=self.max_model_len,
|
||||
)
|
||||
@@ -1000,9 +972,7 @@ class SpecDecodeBaseProposer:
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
tree_attn_metadata_builder = self.runner.attn_groups[0][
|
||||
0
|
||||
].get_metadata_builder()
|
||||
tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
|
||||
|
||||
total_num_drafts = self.cu_drafts_per_level[0]
|
||||
@@ -1078,10 +1048,11 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata=common_attn_metadata, draft_index=level + 1
|
||||
)
|
||||
|
||||
# Apply new attention metadata to all layers.
|
||||
# Apply new attention metadata to all draft layers.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
# Consider max model length.
|
||||
attn_metadata.max_seq_len = min(
|
||||
@@ -1288,43 +1259,17 @@ class SpecDecodeBaseProposer:
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
)
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
).keys()
|
||||
)
|
||||
|
||||
self.model = self._get_model()
|
||||
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
).keys()
|
||||
- target_attn_layer_names
|
||||
# Find draft layers (attention layers added by draft model)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache
|
||||
self._draft_attn_layer_names = (
|
||||
set(all_attn_layers.keys()) - target_attn_layer_names
|
||||
)
|
||||
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
|
||||
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
|
||||
self.indexer_layer_names = list(draft_indexer_layer_names)
|
||||
|
||||
if self.indexer_layer_names:
|
||||
first_layer = self.indexer_layer_names[0]
|
||||
self.draft_indexer_metadata_builder = (
|
||||
indexer_layers[first_layer]
|
||||
.get_attn_backend()
|
||||
.get_builder_cls()(
|
||||
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
|
||||
self.indexer_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.draft_indexer_metadata_builder = None
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
# Even if the target model is multimodal, we can also use
|
||||
@@ -1562,9 +1507,9 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
# Make sure to use EAGLE's own buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
self._draft_attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
and next(iter(self._draft_attn_layer_names)) in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
@@ -1594,31 +1539,6 @@ class SpecDecodeBaseProposer:
|
||||
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
|
||||
self.model(**kwargs)
|
||||
|
||||
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
|
||||
Returns:
|
||||
The metadata builders for EAGLE layers.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no metadata builders are found for EAGLE layers.
|
||||
"""
|
||||
builder = None
|
||||
chosen_layer = self.attn_layer_names[0]
|
||||
|
||||
for kv_cache_group in self.runner.attn_groups:
|
||||
for attn_group in kv_cache_group:
|
||||
if chosen_layer in attn_group.layer_names:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
break
|
||||
if builder is not None:
|
||||
break
|
||||
|
||||
assert builder is not None, (
|
||||
"Failed to find attention metadata builder for EAGLE layers."
|
||||
)
|
||||
return builder
|
||||
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
||||
"""
|
||||
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
|
||||
@@ -1651,13 +1571,71 @@ class SpecDecodeBaseProposer:
|
||||
set(
|
||||
[
|
||||
kv_cache_groups[layer_name]
|
||||
for layer_name in self.attn_layer_names
|
||||
for layer_name in self._draft_attn_layer_names
|
||||
]
|
||||
)
|
||||
)
|
||||
== 1
|
||||
), "All drafting layers should belong to the same kv cache group"
|
||||
|
||||
def initialize_attn_backend(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kernel_block_sizes: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AttentionGroups for draft layers using kv_cache_config.
|
||||
Called from the model runner's initialize_metadata_builders.
|
||||
"""
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
|
||||
# Find which kv_cache_group the draft layers belong to
|
||||
self.validate_same_kv_cache_group(kv_cache_config)
|
||||
kv_cache_spec = None
|
||||
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
if self._draft_attn_layer_names & set(group.layer_names):
|
||||
self.kv_cache_gid = gid
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
break
|
||||
|
||||
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
|
||||
if kv_cache_spec is not None:
|
||||
for layer_name in self._draft_attn_layer_names:
|
||||
attn_backend = all_attn_layers[layer_name].get_attn_backend()
|
||||
backend_key = attn_backend.full_cls_name()
|
||||
if backend_key not in attention_groups:
|
||||
layer_kv_cache_spec = kv_cache_spec
|
||||
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
|
||||
layer_name
|
||||
]
|
||||
|
||||
kernel_block_size = (
|
||||
kernel_block_sizes[self.kv_cache_gid]
|
||||
if kernel_block_sizes is not None
|
||||
and self.kv_cache_gid < len(kernel_block_sizes)
|
||||
else None
|
||||
)
|
||||
attn_group = AttentionGroup(
|
||||
backend=attn_backend,
|
||||
layer_names=[layer_name],
|
||||
kv_cache_spec=layer_kv_cache_spec,
|
||||
kv_cache_group_id=self.kv_cache_gid,
|
||||
)
|
||||
attn_group.create_metadata_builders(
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
kernel_block_size=kernel_block_size,
|
||||
)
|
||||
attention_groups[backend_key] = attn_group
|
||||
else:
|
||||
attention_groups[backend_key].layer_names.append(layer_name)
|
||||
|
||||
self.draft_attn_groups = list(attention_groups.values())
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
|
||||
@@ -1936,7 +1936,7 @@ class GPUModelRunner(
|
||||
|
||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
|
||||
if self.drafter.kv_cache_gid == kv_cache_gid:
|
||||
spec_decode_common_attn_metadata = cm
|
||||
else:
|
||||
spec_decode_common_attn_metadata = cm
|
||||
@@ -5559,6 +5559,14 @@ class GPUModelRunner(
|
||||
# because some of them change the threshold at init time.
|
||||
self.calculate_reorder_batch_threshold()
|
||||
|
||||
# Initialize drafter attention backend
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
|
||||
|
||||
def _check_and_update_cudagraph_mode(
|
||||
self,
|
||||
attention_backends: list[set[type[AttentionBackend]]],
|
||||
@@ -6079,15 +6087,11 @@ class GPUModelRunner(
|
||||
kv_cache_config, kernel_block_sizes
|
||||
)
|
||||
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
or self.speculative_config.uses_extract_hidden_states()
|
||||
if (
|
||||
self.speculative_config
|
||||
and self.speculative_config.uses_extract_hidden_states()
|
||||
):
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
)
|
||||
assert isinstance(self.drafter, ExtractHiddenStatesProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
@@ -48,7 +48,7 @@ class AttentionGroup:
|
||||
self,
|
||||
vllm_config,
|
||||
device,
|
||||
kernel_block_size: int | None,
|
||||
kernel_block_size: int | None = None,
|
||||
num_metadata_builders: int = 1,
|
||||
):
|
||||
kv_cache_spec_builder = (
|
||||
|
||||
Reference in New Issue
Block a user