[V1] Support Deepseek MTP (#18435)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
Jiayi Yao
2025-05-23 12:26:28 -05:00
committed by GitHub
parent 371f7e4ca2
commit 2628a69e35
6 changed files with 120 additions and 66 deletions

View File

@@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = False
if self.speculative_config:
self.use_spec_decode = True
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
self.drafter = EagleProposer(self.vllm_config, self.device,
self) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa":
@@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
# NOTE: deepseek_mtp uses MLA, whose attention metadata does not
# have a `block_table` attribute, so fall back to None in that case.
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_table,
block_table=block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
@@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
hidden_states = outputs
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
if self.use_spec_decode and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)