[V1][Spec Decode] Support multi-layer eagle draft model (#18030)

Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
qizixi
2025-05-24 02:45:34 -07:00
committed by GitHub
parent a859320575
commit c1e4a4052d
3 changed files with 45 additions and 9 deletions

View File

@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = async_tensor_h2d(next_token_ids,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_names[0]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,