diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 5788b31d2..d10530c95 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -430,6 +430,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): # dummy run the eagle speculator's propose to ensure DP/EP sync. if self.speculator is not None: assert self.sampler is not None + mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None + if self.speculator.supports_mm_inputs: + mm_inputs = ( + [], + torch.zeros( + input_batch.num_tokens, + dtype=torch.bool, + device=self.device, + ), + ) self.speculator.propose( input_batch=input_batch, attn_metadata=attn_metadata, @@ -449,6 +459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp, dummy_run=True, skip_attn_for_dummy_run=skip_attn, + mm_inputs=mm_inputs, ) assert hidden_states is not None # Last PP rank always has hidden_states @@ -1142,8 +1153,27 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.postprocess( input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected ) + if self.speculator is not None: assert self.sampler is not None + mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None + if self.speculator.supports_mm_inputs: + # Get cached multimodal embeddings for draft forward. + prefill_lens = self.req_states.prefill_len.np[ + input_batch.idx_mapping_np + ] + computed_prefill_lens = self.req_states.num_computed_prefill_tokens[ + input_batch.idx_mapping_np + ] + mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings( + input_batch.req_ids, + input_batch.num_tokens, + input_batch.num_scheduled_tokens, + input_batch.query_start_loc_np, + prefill_lens, + computed_prefill_lens + 1, # + 1 to consider the skew in eagle + ) + draft_tokens = self.speculator.propose( input_batch, attn_metadata, @@ -1157,6 +1187,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, num_tokens_across_dp=num_tokens_across_dp, + mm_inputs=mm_inputs, ) self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 4df88bf95..bc001db8e 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -10,6 +10,7 @@ from vllm.config.compilation import CUDAGraphMode from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.triton_utils import tl, triton from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import ( @@ -76,6 +77,14 @@ class EagleSpeculator: device=device, ) + self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs( + self.draft_model_config + ) + if self.supports_mm_inputs: + self.inputs_embeds = torch.zeros( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device + ) + cache_draft_logits = self.speculative_config.rejection_sample_method != "strict" self.draft_logits: torch.Tensor | None = None if cache_draft_logits: @@ -138,6 +147,7 @@ class EagleSpeculator: slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: batch_descriptor = BatchDescriptor(num_tokens=num_tokens) with set_forward_context( @@ -149,10 +159,25 @@ class EagleSpeculator: slot_mapping=slot_mappings, batch_descriptor=batch_descriptor, ): + inputs_embeds = None + if self.supports_mm_inputs: + # Merge multimodal embeddings with input ids. + mm_embeds, is_mm_embed = mm_inputs or (None, None) + num_input_tokens = ( + is_mm_embed.shape[0] if is_mm_embed is not None else num_tokens + ) + self.inputs_embeds[:num_input_tokens] = self.model.embed_input_ids( + self.input_buffers.input_ids[:num_input_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + inputs_embeds = self.inputs_embeds[:num_tokens] + ret_hidden_states = self.model( input_ids=self.input_buffers.input_ids[:num_tokens], positions=self.input_buffers.positions[:num_tokens], hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, ) if self.method == "mtp": last_hidden_states = ret_hidden_states @@ -254,6 +279,7 @@ class EagleSpeculator: num_tokens_across_dp: torch.Tensor | None = None, dummy_run: bool = False, skip_attn_for_dummy_run: bool = False, + mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, ) -> torch.Tensor: # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # number of rejected tokens, we maintain the size of eagle's input_ids and @@ -288,6 +314,7 @@ class EagleSpeculator: attn_metadata, slot_mappings, num_tokens_across_dp=num_tokens_across_dp, + mm_inputs=mm_inputs, ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states)