[V1] Support Deepseek MTP (#18435)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
@@ -151,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
|
||||
# NOTE(Jiayi): currently we put the entire draft model on
|
||||
# the last PP rank. This is not ideal if there are many
|
||||
# layers in the draft model.
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
elif self.speculative_config.use_eagle():
|
||||
self.drafter = EagleProposer(self.vllm_config,
|
||||
self.device) # type: ignore
|
||||
self.drafter = EagleProposer(self.vllm_config, self.device,
|
||||
self) # type: ignore
|
||||
if self.speculative_config.method == "eagle3":
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
elif self.speculative_config.method == "medusa":
|
||||
@@ -1361,6 +1365,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
device=self.device)
|
||||
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
|
||||
|
||||
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
|
||||
if hasattr(eagle_attn_metadata, "block_table"):
|
||||
block_table = eagle_attn_metadata.block_table
|
||||
else:
|
||||
block_table = None
|
||||
|
||||
if spec_decode_metadata is None:
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||
@@ -1406,7 +1416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=eagle_attn_metadata.block_table,
|
||||
block_table=block_table,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
@@ -1723,8 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
hidden_states = outputs
|
||||
|
||||
if self.use_spec_decode and \
|
||||
self.speculative_config.method in ('eagle', 'eagle3'):
|
||||
if self.use_spec_decode and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user