[Model Runner V2] Support multi-modal embeddings for spec decode model (#36097)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -430,6 +430,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
if self.speculator is not None:
|
||||
assert self.sampler is not None
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||
if self.speculator.supports_mm_inputs:
|
||||
mm_inputs = (
|
||||
[],
|
||||
torch.zeros(
|
||||
input_batch.num_tokens,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
self.speculator.propose(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
@@ -449,6 +459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
dummy_run=True,
|
||||
skip_attn_for_dummy_run=skip_attn,
|
||||
mm_inputs=mm_inputs,
|
||||
)
|
||||
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
@@ -1142,8 +1153,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.postprocess(
|
||||
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
|
||||
)
|
||||
|
||||
if self.speculator is not None:
|
||||
assert self.sampler is not None
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
|
||||
if self.speculator.supports_mm_inputs:
|
||||
# Get cached multimodal embeddings for draft forward.
|
||||
prefill_lens = self.req_states.prefill_len.np[
|
||||
input_batch.idx_mapping_np
|
||||
]
|
||||
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
|
||||
input_batch.idx_mapping_np
|
||||
]
|
||||
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
prefill_lens,
|
||||
computed_prefill_lens + 1, # + 1 to consider the skew in eagle
|
||||
)
|
||||
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
attn_metadata,
|
||||
@@ -1157,6 +1187,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
mm_inputs=mm_inputs,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
@@ -76,6 +77,14 @@ class EagleSpeculator:
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
|
||||
self.draft_model_config
|
||||
)
|
||||
if self.supports_mm_inputs:
|
||||
self.inputs_embeds = torch.zeros(
|
||||
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
|
||||
)
|
||||
|
||||
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
|
||||
self.draft_logits: torch.Tensor | None = None
|
||||
if cache_draft_logits:
|
||||
@@ -138,6 +147,7 @@ class EagleSpeculator:
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
|
||||
with set_forward_context(
|
||||
@@ -149,10 +159,25 @@ class EagleSpeculator:
|
||||
slot_mapping=slot_mappings,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs:
|
||||
# Merge multimodal embeddings with input ids.
|
||||
mm_embeds, is_mm_embed = mm_inputs or (None, None)
|
||||
num_input_tokens = (
|
||||
is_mm_embed.shape[0] if is_mm_embed is not None else num_tokens
|
||||
)
|
||||
self.inputs_embeds[:num_input_tokens] = self.model.embed_input_ids(
|
||||
self.input_buffers.input_ids[:num_input_tokens],
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_buffers.input_ids[:num_tokens],
|
||||
positions=self.input_buffers.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
@@ -254,6 +279,7 @@ class EagleSpeculator:
|
||||
num_tokens_across_dp: torch.Tensor | None = None,
|
||||
dummy_run: bool = False,
|
||||
skip_attn_for_dummy_run: bool = False,
|
||||
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
|
||||
# number of rejected tokens, we maintain the size of eagle's input_ids and
|
||||
@@ -288,6 +314,7 @@ class EagleSpeculator:
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
mm_inputs=mm_inputs,
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user