[Spec Decode] Unified Parallel Drafting (#32887)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Author: Benjamin Chislett
Date: 2026-02-05 12:37:18 -05:00
Committed by: GitHub
Parent: 5b2a9422f0
Commit: af3162d3aa
14 changed files with 1085 additions and 392 deletions


@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         # Subsequent layers use hidden_size (only hidden_states, no embeds)
         qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
         # override qkv
+        # Parallel drafting checkpoints may have attention bias enabled
+        qkv_bias = getattr(config, "attention_bias", False)
+        # Override qkv_proj with correct input size and bias setting
         self.self_attn.qkv_proj = QKVParallelLinear(
             qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
-            bias=False,
+            bias=qkv_bias,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "qkv_proj"),
         )
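
For readers unfamiliar with the EAGLE3 draft head: the first decoder layer consumes the token embeddings together with the target model's hidden states (hence 2 * hidden_size inputs), while later layers see only hidden_states. A minimal, hypothetical sketch of that sizing, using a plain nn.Linear as a stand-in for QKVParallelLinear rather than the vLLM implementation:

import torch
import torch.nn as nn

hidden_size, qkv_out = 64, 192  # toy sizes for illustration


def make_qkv(layer_idx: int, attention_bias: bool) -> nn.Linear:
    # Layer 0 sees [embeds ; hidden_states] -> 2 * hidden_size inputs;
    # subsequent layers see only hidden_states -> hidden_size inputs.
    in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
    # Parallel-drafting checkpoints may ship an attention bias, so honor it.
    return nn.Linear(in_size, qkv_out, bias=attention_bias)


embeds = torch.randn(2, 5, hidden_size)
hidden_states = torch.randn(2, 5, hidden_size)
qkv0 = make_qkv(0, attention_bias=True)
projected = qkv0(torch.cat([embeds, hidden_states], dim=-1))  # (2, 5, 192)
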
@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             requires_grad=False,
         )
+        self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting
+        if self.use_parallel_drafting:
+            self.register_buffer(
+                "mask_hidden",
+                torch.zeros(
+                    1,
+                    (3 if self.model.use_aux_hidden_state else 1)
+                    * self.config.hidden_size,
+                ),
+                persistent=False,
+            )
 
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
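
The mask_hidden buffer is sized to match the draft head's hidden input: three concatenated auxiliary hidden states (3 * hidden_size) when use_aux_hidden_state is set, otherwise a single hidden_size vector. A rough sketch of the shape logic; the final line shows a plausible but assumed use of the buffer as a fill-in row for positions being drafted in parallel:

import torch

hidden_size = 64
use_aux_hidden_state = True  # mirrors self.model.use_aux_hidden_state

width = (3 if use_aux_hidden_state else 1) * hidden_size
mask_hidden = torch.zeros(1, width)  # registered as a non-persistent buffer

# Assumed usage: substitute the learned row at masked (parallel-draft) slots.
hidden = torch.randn(8, width)
is_masked = torch.tensor([True, False] * 4)
hidden[is_masked] = mask_hidden  # the (1, width) row broadcasts over masked rows
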
@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         model_weights = {}
         includes_draft_id_mapping = False
         includes_embed_tokens = False
+        includes_mask_hidden = False
         for name, loaded_weight in weights:
             if "t2d" in name:
                 continue
             if "d2t" in name:
                 name = name.replace("d2t", "draft_id_to_target_id")
                 includes_draft_id_mapping = True
+            elif "mask_hidden" in name:
+                # Load mask_hidden directly into buffer
+                if not self.use_parallel_drafting:
+                    logger.warning(
+                        "mask_hidden found in weights but "
+                        "model is not configured for parallel drafting. "
+                        "Skipping loading mask_hidden."
+                    )
+                    continue
+                self.mask_hidden.copy_(loaded_weight.view(1, -1))
+                includes_mask_hidden = True
+                continue
             elif "lm_head" not in name:
                 name = "model." + name
             if "embed_tokens" in name:
@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             model_weights[name] = loaded_weight
             process_eagle_weight(self, name)
 
-        skip_substrs = []
+        if not includes_mask_hidden and self.use_parallel_drafting:
+            raise ValueError(
+                "mask_hidden not found in weights but "
+                "model is configured for parallel drafting. "
+                "Please provide mask_hidden in the weights."
+            )
+        skip_substrs = ["mask_hidden"]
         if not includes_draft_id_mapping:
             skip_substrs.append("draft_id_to_target_id")
         if not includes_embed_tokens:
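
Since mask_hidden is consumed by the direct buffer copy above, it is unconditionally added to skip_substrs so the generic loader can ignore any matching name; draft_id_to_target_id and embed_tokens are skipped only when the checkpoint did not supply them. A hedged sketch of that kind of substring filtering (the surrounding loader machinery and the embed_tokens branch body are assumed, as they are cut off in this diff):

skip_substrs = ["mask_hidden"]
includes_draft_id_mapping = False
includes_embed_tokens = False
if not includes_draft_id_mapping:
    skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
    skip_substrs.append("embed_tokens")

names = [
    "model.layers.0.self_attn.qkv_proj.weight",
    "mask_hidden",
    "draft_id_to_target_id",
]
kept = [n for n in names if not any(s in n for s in skip_substrs)]
# kept == ["model.layers.0.self_attn.qkv_proj.weight"]
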