[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # For mid-pipeline stages, return the hidden states.
             return hidden_states
 
-        hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
 
@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-            hidden_states,
+            hidden_states[:num_scheduled_tokens],
             scheduler_output,
         )
 
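Note: the two hunks above move the `[:num_scheduled_tokens]` slice from a single up-front truncation to the individual consumers. With cudagraph applied to the EAGLE head, the forward pass may run at a padded token count, so `hidden_states` stays padded: `compute_logits` already selects only valid rows via `logits_indices`, while the prompt-logprobs path slices explicitly. A minimal sketch of this pad-then-slice pattern, with illustrative sizes and stand-in helpers (`pad_to_captured` and `fake_forward` are hypothetical, not vLLM APIs):

```python
import torch

# Hypothetical capture sizes; vLLM derives these from its cudagraph config.
CAPTURE_SIZES = (8, 16, 32)

def pad_to_captured(n: int) -> int:
    # Smallest captured batch size that fits the real token count.
    return next(s for s in CAPTURE_SIZES if s >= n)

def fake_forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # stand-in for the graph-captured model forward

num_scheduled_tokens = 11
num_input_tokens = pad_to_captured(num_scheduled_tokens)   # -> 16
hidden_states = fake_forward(torch.randn(num_input_tokens, 64))

# Consumer 1: sampling touches only valid rows via explicit indices.
logits_indices = torch.tensor([4, 10])  # last token of each request
sample_hidden_states = hidden_states[logits_indices]

# Consumer 2: prompt logprobs need the contiguous unpadded prefix.
valid_hidden_states = hidden_states[:num_scheduled_tokens]
assert valid_hidden_states.shape[0] == num_scheduled_tokens
```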
@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
-            # We need to slice token_ids, positions, and hidden_states
-            # because the eagle head does not use cuda graph and should
-            # not include padding.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[:num_scheduled_tokens] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                    dim=-1)
             else:
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
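In this branch the EAGLE-3 auxiliary hidden states (one tensor per tapped intermediate layer) are now concatenated along the feature dimension at the point of slicing, so `target_hidden_states` is always a single tensor rather than a list of tensors. A shape-only sketch with illustrative dimensions:

```python
import torch

num_scheduled_tokens, padded_tokens, hidden_size = 11, 16, 64
# One padded tensor per tapped layer (three layers here, illustratively).
aux_hidden_states = [torch.randn(padded_tokens, hidden_size) for _ in range(3)]

target_hidden_states = torch.cat(
    [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
assert target_hidden_states.shape == (num_scheduled_tokens, 3 * hidden_size)
```

Handing the drafter one tensor of fixed rank and dtype, instead of a Python list, keeps its compiled entry point simple for torch.compile.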
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[token_indices] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]
 
-        if self.use_aux_hidden_state_outputs:
-            target_hidden_states = torch.cat(target_hidden_states, dim=-1)
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
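This hunk mirrors the previous one for the speculative-decode path, where valid tokens are gathered by `token_indices` rather than a contiguous slice; the trailing `torch.cat` that previously ran after the if/else is folded into the aux branch, so `propose` always receives a tensor. A small sketch of the gather (indices are made up):

```python
import torch

positions = torch.arange(16)                      # padded flat positions
token_indices = torch.tensor([0, 1, 2, 5, 6, 9])  # e.g. accepted-token slots

# Fancy indexing copies into a new contiguous, padding-free tensor,
# matching what a compiled drafter forward expects as input.
target_positions = positions[token_indices]
assert target_positions.tolist() == [0, 1, 2, 5, 6, 9]
```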
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             hidden_states = outputs
 
+        if self.use_spec_decode and \
+            self.speculative_config.method in ('eagle', 'eagle3'):
+            assert isinstance(self.drafter, EagleProposer)
+            self.drafter.dummy_run(num_tokens)
+
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         return hidden_states[logit_indices]
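The new block runs the drafter during the dummy run, which is what lets the EAGLE head be captured into CUDA graphs: capture requires executing the model once at each padded size. A hedged sketch of the pattern, assuming a `dummy_run` that just forwards zeros at the padded size (`ToyDrafter` is illustrative, not the vLLM `EagleProposer` API):

```python
import torch

class ToyDrafter:
    """Stand-in for an EAGLE-style draft head."""

    def __init__(self, hidden_size: int = 64):
        self.head = torch.nn.Linear(hidden_size, hidden_size)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        # Forward once at the padded size; under cudagraph capture this
        # records a replayable graph keyed by num_tokens.
        self.head(torch.zeros(num_tokens, self.head.in_features))

drafter = ToyDrafter()
for num_tokens in (8, 16, 32):  # illustrative capture sizes
    drafter.dummy_run(num_tokens)
```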