[Spec Decode] Unified Parallel Drafting (#32887)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
commit af3162d3aa (parent 5b2a9422f0)
@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         # Subsequent layers use hidden_size (only hidden_states, no embeds)
         qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
 
-        # override qkv
+        # Parallel drafting checkpoints may have attention bias enabled
+        qkv_bias = getattr(config, "attention_bias", False)
+
+        # Override qkv_proj with correct input size and bias setting
         self.self_attn.qkv_proj = QKVParallelLinear(
             qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
-            bias=False,
+            bias=qkv_bias,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "qkv_proj"),
         )
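
Note (not part of the diff): a minimal sketch of the shape reasoning behind qkv_input_size. EAGLE-style drafters feed the first decoder layer a concatenation of the token embedding and the target model's hidden state, so layer 0 sees 2 * hidden_size features while later layers see only hidden_size; qkv_bias simply falls back to False when the checkpoint config has no attention_bias field. All names and sizes below are illustrative.

import torch

hidden_size = 4096
embeds = torch.randn(1, 8, hidden_size)         # token embeddings for 8 draft positions
target_hidden = torch.randn(1, 8, hidden_size)  # hidden states taken from the target model

# Layer 0 attends over the concatenation, hence the 2 * hidden_size input width.
layer0_input = torch.cat([embeds, target_hidden], dim=-1)
assert layer0_input.shape[-1] == 2 * hidden_size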
@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             requires_grad=False,
         )
 
+        self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting
+
+        if self.use_parallel_drafting:
+            self.register_buffer(
+                "mask_hidden",
+                torch.zeros(
+                    1,
+                    (3 if self.model.use_aux_hidden_state else 1)
+                    * self.config.hidden_size,
+                ),
+                persistent=False,
+            )
+
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
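
Note (not part of the diff): a small PyTorch-only sketch of the buffer behaviour this relies on, assumed from standard register_buffer semantics rather than taken from this PR. A buffer registered with persistent=False follows the module across .to()/.cuda() calls but is excluded from state_dict(), which is why load_weights below copies mask_hidden into it explicitly instead of going through the normal state-dict path. Sizes are illustrative.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, hidden_size: int = 8, num_aux: int = 3):
        super().__init__()
        # Mirrors the mask_hidden registration above, with tiny sizes.
        self.register_buffer(
            "mask_hidden", torch.zeros(1, num_aux * hidden_size), persistent=False
        )

m = Toy()
assert "mask_hidden" not in m.state_dict()        # not saved or auto-loaded
m.mask_hidden.copy_(torch.randn(24).view(1, -1))  # filled explicitly at load time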
@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         model_weights = {}
         includes_draft_id_mapping = False
         includes_embed_tokens = False
+        includes_mask_hidden = False
         for name, loaded_weight in weights:
             if "t2d" in name:
                 continue
             if "d2t" in name:
                 name = name.replace("d2t", "draft_id_to_target_id")
                 includes_draft_id_mapping = True
+            elif "mask_hidden" in name:
+                # Load mask_hidden directly into buffer
+                if not self.use_parallel_drafting:
+                    logger.warning(
+                        "mask_hidden found in weights but "
+                        "model is not configured for parallel drafting. "
+                        "Skipping loading mask_hidden."
+                    )
+                    continue
+                self.mask_hidden.copy_(loaded_weight.view(1, -1))
+                includes_mask_hidden = True
+                continue
             elif "lm_head" not in name:
                 name = "model." + name
             if "embed_tokens" in name:
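
Note (not part of the diff): a hedged, standalone sketch of the per-weight rules in this loop. "t2d" entries are dropped, "d2t" is renamed to "draft_id_to_target_id", and "mask_hidden" is copied straight into the buffer, with .view(1, -1) normalizing a flat checkpoint tensor to the (1, N) buffer shape. Tensor contents and sizes are made up.

import torch

mask_hidden = torch.zeros(1, 24)  # stand-in for the registered buffer

weights = [
    ("t2d", torch.arange(10)),                   # dropped entirely
    ("d2t", torch.arange(10)),                   # renamed before being kept
    ("mask_hidden", torch.randn(24)),            # copied into the buffer
    ("layers.0.self_attn.qkv_proj.weight", torch.randn(4, 4)),
]

kept = {}
for name, w in weights:
    if "t2d" in name:
        continue
    if "d2t" in name:
        name = name.replace("d2t", "draft_id_to_target_id")
    elif "mask_hidden" in name:
        mask_hidden.copy_(w.view(1, -1))
        continue
    kept[name] = w

assert set(kept) == {"draft_id_to_target_id", "layers.0.self_attn.qkv_proj.weight"}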
@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             model_weights[name] = loaded_weight
             process_eagle_weight(self, name)
 
-        skip_substrs = []
+        if not includes_mask_hidden and self.use_parallel_drafting:
+            raise ValueError(
+                "mask_hidden not found in weights but "
+                "model is configured for parallel drafting. "
+                "Please provide mask_hidden in the weights."
+            )
+
+        skip_substrs = ["mask_hidden"]
         if not includes_draft_id_mapping:
             skip_substrs.append("draft_id_to_target_id")
         if not includes_embed_tokens:
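
Note (not part of the diff): a hedged sketch of what skip_substrs is for. In vLLM this list is typically handed to the generic weight loader so that parameters whose names contain any of these substrings are excluded from strict "every weight must be loaded" checks; mask_hidden is filled manually above, and the id mapping or embeddings may legitimately be absent from a checkpoint. The filter below only illustrates the idea and is not the loader itself.

def should_skip(param_name: str, skip_substrs: list[str]) -> bool:
    # Illustrative-only substring filter, mirroring how skip_substrs is meant to act.
    return any(s in param_name for s in skip_substrs)

skip_substrs = ["mask_hidden", "draft_id_to_target_id"]
assert should_skip("mask_hidden", skip_substrs)
assert not should_skip("layers.0.self_attn.qkv_proj.weight", skip_substrs)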