[Model Runner V2] Support multi-modal embeddings for spec decode model (#36097)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-22 02:48:43 -07:00
committed by GitHub
parent cd1242d82a
commit b3e846017d
2 changed files with 58 additions and 0 deletions

View File

@@ -430,6 +430,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Dummy-run the Eagle speculator's propose() as well, to keep DP/EP in sync.
if self.speculator is not None:
assert self.sampler is not None
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator.supports_mm_inputs:
mm_inputs = (
[],
torch.zeros(
input_batch.num_tokens,
dtype=torch.bool,
device=self.device,
),
)
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
@@ -449,6 +459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
mm_inputs=mm_inputs,
)
assert hidden_states is not None # Last PP rank always has hidden_states
@@ -1142,8 +1153,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.speculator is not None:
assert self.sampler is not None
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator.supports_mm_inputs:
# Gather the cached multimodal embeddings for the draft model's forward pass.
prefill_lens = self.req_states.prefill_len.np[
input_batch.idx_mapping_np
]
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
input_batch.idx_mapping_np
]
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
prefill_lens,
computed_prefill_lens + 1, # + 1 to consider the skew in eagle
)
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
@@ -1157,6 +1187,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
mm_inputs=mm_inputs,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

View File

@@ -10,6 +10,7 @@ from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
@@ -76,6 +77,14 @@ class EagleSpeculator:
device=device,
)
self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
self.draft_model_config
)
if self.supports_mm_inputs:
self.inputs_embeds = torch.zeros(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
)
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
@@ -138,6 +147,7 @@ class EagleSpeculator:
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context(
@@ -149,10 +159,25 @@ class EagleSpeculator:
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
inputs_embeds = None
if self.supports_mm_inputs:
# Embed the input ids, merging in multimodal embeddings where flagged.
mm_embeds, is_mm_embed = mm_inputs or (None, None)
num_input_tokens = (
is_mm_embed.shape[0] if is_mm_embed is not None else num_tokens
)
self.inputs_embeds[:num_input_tokens] = self.model.embed_input_ids(
self.input_buffers.input_ids[:num_input_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
inputs_embeds = self.inputs_embeds[:num_tokens]
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
@@ -254,6 +279,7 @@ class EagleSpeculator:
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -288,6 +314,7 @@ class EagleSpeculator:
attn_metadata,
slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
mm_inputs=mm_inputs,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)