[Model Runner V2] Gather multimodal embeddings before draft model postprocess (#37932)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-23 18:14:13 -07:00
committed by GitHub
parent 56777b5c89
commit 8f4824b664

View File

@@ -1162,6 +1162,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
copy_event=self.output_copy_event,
)
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator is not None and self.speculator.supports_mm_inputs:
# Get cached multimodal embeddings for draft forward.
# NOTE: This is done here because postprocess updates
# num_computed_prefill_tokens.
prefill_lens = self.req_states.prefill_len.np[input_batch.idx_mapping_np]
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
input_batch.idx_mapping_np
]
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
prefill_lens,
computed_prefill_lens + 1, # +1 to consider the skew in eagle
)
# Postprocess results and update request states.
# NOTE: This is intentionally done after creating the AsyncOutput,
# ensuring that `copy_event` is recorded before calling postprocess.
@@ -1173,24 +1191,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None:
assert self.sampler is not None
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator.supports_mm_inputs:
# Get cached multimodal embeddings for draft forward.
prefill_lens = self.req_states.prefill_len.np[
input_batch.idx_mapping_np
]
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
input_batch.idx_mapping_np
]
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
prefill_lens,
computed_prefill_lens + 1, # + 1 to consider the skew in eagle
)
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,