[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Bryan Lu
2025-04-29 14:10:00 -07:00
committed by GitHub
parent c9c1b59e59
commit 70788bdbdc
6 changed files with 152 additions and 53 deletions

View File

@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For mid-pipeline stages, return the hidden states.
return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states,
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = [
h[:num_scheduled_tokens] for h in aux_hidden_states
]
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = [
h[token_indices] for h in aux_hidden_states
]
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
hidden_states = outputs
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]