[V1][Spec Decode] Support multi-layer eagle draft model (#18030)
Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
@@ -1360,11 +1360,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
next_token_id = req_state.get_token_id(seq_len)
|
||||
next_token_ids.append(next_token_id)
|
||||
next_token_ids = async_tensor_h2d(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
target_device=self.device,
|
||||
pin_memory=True)
|
||||
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
|
||||
next_token_ids = torch.tensor(next_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
eagle_attn_metadata = attn_metadata[
|
||||
self.drafter.attn_layer_names[0]]
|
||||
|
||||
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
|
||||
if hasattr(eagle_attn_metadata, "block_table"):
|
||||
@@ -2018,6 +2020,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# KV cache specs.
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
|
||||
Reference in New Issue
Block a user