[Model Runner V2] Gather multimodal embeddings before draft model postprocess (#37932)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -1162,6 +1162,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
copy_event=self.output_copy_event,
|
||||
)
|
||||
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||
if self.speculator is not None and self.speculator.supports_mm_inputs:
|
||||
# Get cached multimodal embeddings for draft forward.
|
||||
# NOTE: This is done here because postprocess updates
|
||||
# num_computed_prefill_tokens.
|
||||
prefill_lens = self.req_states.prefill_len.np[input_batch.idx_mapping_np]
|
||||
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
|
||||
input_batch.idx_mapping_np
|
||||
]
|
||||
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
prefill_lens,
|
||||
computed_prefill_lens + 1, # +1 to consider the skew in eagle
|
||||
)
|
||||
|
||||
# Postprocess results and update request states.
|
||||
# NOTE: This is intentionally done after creating the AsyncOutput,
|
||||
# ensuring that `copy_event` is recorded before calling postprocess.
|
||||
@@ -1173,24 +1191,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if self.speculator is not None:
|
||||
assert self.sampler is not None
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||
if self.speculator.supports_mm_inputs:
|
||||
# Get cached multimodal embeddings for draft forward.
|
||||
prefill_lens = self.req_states.prefill_len.np[
|
||||
input_batch.idx_mapping_np
|
||||
]
|
||||
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
|
||||
input_batch.idx_mapping_np
|
||||
]
|
||||
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
prefill_lens,
|
||||
computed_prefill_lens + 1, # + 1 to consider the skew in eagle
|
||||
)
|
||||
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
attn_metadata,
|
||||
|
||||
Reference in New Issue
Block a user