From bb1848cd62e5e5f327f307bee26e2b947f00396e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 18 Jan 2026 16:58:51 -0800 Subject: [PATCH] [Model Runner V2] Support VLM (#32546) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/cudagraph_utils.py | 7 + vllm/v1/worker/gpu/input_batch.py | 6 +- vllm/v1/worker/gpu/mm/encoder_runner.py | 184 ++++++++++++++++++++++++ vllm/v1/worker/gpu/mm/mrope_utils.py | 8 +- vllm/v1/worker/gpu/model_runner.py | 70 ++++++++- vllm/v1/worker/gpu/spec_decode/eagle.py | 3 - 6 files changed, 263 insertions(+), 15 deletions(-) create mode 100644 vllm/v1/worker/gpu/mm/encoder_runner.py diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index a19e6383b..9ae31177e 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -76,6 +76,7 @@ class CudaGraphManager: model: nn.Module, input_buffers: InputBuffers, mrope_positions: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, @@ -86,6 +87,8 @@ class CudaGraphManager: if self.uses_mrope: assert mrope_positions is not None positions = mrope_positions[:, :num_tokens] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:num_tokens] attn_metadata = prepare_inputs_to_capture( num_reqs, num_tokens, @@ -108,6 +111,7 @@ class CudaGraphManager: hidden_states = model( input_ids=input_ids, positions=positions, + inputs_embeds=inputs_embeds, ) if self.hidden_states is None: self.hidden_states = torch.empty_like(hidden_states) @@ -128,6 +132,7 @@ class CudaGraphManager: hidden_states = model( input_ids=input_ids, positions=positions, + inputs_embeds=inputs_embeds, ) self.hidden_states[:num_tokens] = hidden_states self.graphs[num_tokens] = graph @@ -138,6 +143,7 @@ class CudaGraphManager: model: nn.Module, input_buffers: InputBuffers, mrope_positions: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, @@ -149,6 +155,7 @@ class CudaGraphManager: model=model, input_buffers=input_buffers, mrope_positions=mrope_positions, + inputs_embeds=inputs_embeds, block_tables=block_tables, attn_metadata_builders=attn_metadata_builders, kv_cache_config=kv_cache_config, diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 00564710c..d6069c4cf 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -15,9 +15,6 @@ class InputBuffers: self, max_num_reqs: int, max_num_tokens: int, - inputs_embeds_size: int, - vocab_size: int, - dtype: torch.dtype, device: torch.device, ): self.max_num_reqs = max_num_reqs @@ -64,6 +61,8 @@ class InputBatch: positions: torch.Tensor # [3, num_tokens_after_padding] mrope_positions: torch.Tensor | None + # [num_tokens_after_padding, hidden_size] + inputs_embeds: torch.Tensor | None # layer_name -> Metadata attn_metadata: dict[str, Any] @@ -132,6 +131,7 @@ class InputBatch: input_ids=input_ids, positions=positions, mrope_positions=None, + inputs_embeds=None, attn_metadata=None, # type: ignore logits_indices=logits_indices, cu_num_logits=cu_num_logits, diff --git a/vllm/v1/worker/gpu/mm/encoder_runner.py b/vllm/v1/worker/gpu/mm/encoder_runner.py new file mode 100644 index 000000000..f9a0b50f3 --- /dev/null +++ b/vllm/v1/worker/gpu/mm/encoder_runner.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np +import torch + +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool +from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs + + +class EncoderRunner: + def __init__( + self, + max_num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ): + self.max_num_tokens = max_num_tokens + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + self.inputs_embeds = torch.zeros( + max_num_tokens, + hidden_size, + dtype=dtype, + device=device, + ) + self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {} + self.encoder_cache: dict[str, torch.Tensor] = {} + + self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool) + + def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]): + self.req_id_to_mm_features[req_id] = mm_features + + def free_encoder_cache(self, mm_hash: str) -> None: + self.encoder_cache.pop(mm_hash, None) + + def remove_request(self, req_id: str) -> None: + self.req_id_to_mm_features.pop(req_id, None) + + def prepare_mm_inputs( + self, + scheduled_encoder_inputs: dict[str, list[int]], + ) -> tuple[list[str], list[MultiModalKwargsItem]]: + mm_hashes: list[str] = [] + mm_kwargs: list[MultiModalKwargsItem] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + mm_features = self.req_id_to_mm_features[req_id] + for mm_input_id in encoder_input_ids: + mm_feature = mm_features[mm_input_id] + if mm_feature.data is None: + continue + mm_hashes.append(mm_feature.identifier) + mm_kwargs.append(mm_feature.data) + return mm_hashes, mm_kwargs + + @torch.inference_mode() + def execute_mm_encoder( + self, + model: SupportsMultiModal, + mm_hashes: list[str], + mm_kwargs: list[MultiModalKwargsItem], + ) -> list[torch.Tensor]: + if not mm_hashes: + return [] + + encoder_outputs: list[torch.Tensor] = [] + for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=False, + ): + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=num_items, + ) + encoder_outputs.extend(curr_group_outputs) + + # Cache the encoder outputs by mm_hash + for mm_hash, output in zip(mm_hashes, encoder_outputs): + self.encoder_cache[mm_hash] = output + return encoder_outputs + + def gather_mm_embeddings( + self, + req_ids: list[str], + total_num_scheduled_tokens: int, + num_scheduled_tokens: np.ndarray, + query_start_loc: np.ndarray, + prefill_lens: np.ndarray, + computed_prefill_lens: np.ndarray, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + is_prefilling = (computed_prefill_lens < prefill_lens).tolist() + all_decode = not any(is_prefilling) + if all_decode: + # All decode requests, so no need to gather any embeddings. + return [], torch.zeros( + total_num_scheduled_tokens, + dtype=torch.bool, + device=self.device, + ) + + query_start = computed_prefill_lens.tolist() + query_end = (computed_prefill_lens + num_scheduled_tokens).tolist() + + mm_embeds: list[torch.Tensor] = [] + is_mm_embed = torch.zeros( + total_num_scheduled_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=False, + ) + for i, req_id in enumerate(req_ids): + if not is_prefilling[i]: + # OPTIMIZATION: Skip decode requests. + continue + + mm_features = self.req_id_to_mm_features[req_id] + for mm_feature in mm_features: + pos_info = mm_feature.mm_position + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length + + if start_pos >= query_end[i]: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= query_start[i]: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(query_start[i] - start_pos, 0) + end_idx = min(query_end[i] - start_pos, num_encoder_tokens) + assert start_idx < end_idx + curr_embeds_start, curr_embeds_end = ( + pos_info.get_embeds_indices_in_range(start_idx, end_idx) + ) + # If there are no embeddings in the current range, we skip + # gathering the embeddings. + if curr_embeds_start == curr_embeds_end: + continue + + mm_hash = mm_feature.identifier + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." + + if (is_embed := pos_info.is_embed) is not None: + is_embed = is_embed[start_idx:end_idx] + mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end] + else: + mm_embeds_item = encoder_output[start_idx:end_idx] + + req_start_pos = query_start_loc[i] + start_pos - query_start[i] + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) + mm_embeds.append(mm_embeds_item) + + # Copy the is_mm_embed tensor to the GPU. + is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed) + return mm_embeds, is_mm_embed + + @torch.inference_mode() + def get_inputs_embeds( + self, + model: SupportsMultiModal, + input_ids: torch.Tensor, + mm_embeds: list[torch.Tensor], + is_mm_embed: torch.Tensor, + ) -> torch.Tensor: + x = model.embed_input_ids( + input_ids, + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + # Copy to the pre-allocated buffer for CUDA graphs. + self.inputs_embeds[: x.shape[0]] = x + return self.inputs_embeds diff --git a/vllm/v1/worker/gpu/mm/mrope_utils.py b/vllm/v1/worker/gpu/mm/mrope_utils.py index 4c915a5c9..968962114 100644 --- a/vllm/v1/worker/gpu/mm/mrope_utils.py +++ b/vllm/v1/worker/gpu/mm/mrope_utils.py @@ -23,7 +23,7 @@ class MRopeState: # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # wasting a lot of CPU memory. self.prefill_mrope_positions = StagedWriteTensor( - (max_num_reqs, 3 * max_model_len), + (max_num_reqs * 3, max_model_len), dtype=torch.int32, device=device, uva_instead_of_gpu=True, @@ -58,9 +58,7 @@ class MRopeState: ) for i in range(3): pos = prefill_mrope_positions[i].tolist() - self.prefill_mrope_positions.stage_write( - req_idx, i * self.max_model_len, pos - ) + self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos) self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta def apply_staged_writes(self) -> None: @@ -79,7 +77,7 @@ class MRopeState: self.mrope_positions, self.mrope_positions.stride(0), self.prefill_mrope_positions.gpu, - self.prefill_mrope_positions.gpu.stride(0), + 3 * self.max_model_len, self.max_model_len, self.prefill_mrope_delta.gpu, idx_mapping, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 63635640b..1dc844bb3 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -14,6 +14,7 @@ from vllm.config.compilation import CUDAGraphMode from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -47,6 +48,7 @@ from vllm.v1.worker.gpu.input_batch import ( prepare_pos_seq_lens, prepare_prefill_inputs, ) +from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs from vllm.v1.worker.gpu.sample.output import SamplerOutput @@ -95,6 +97,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.inputs_embeds_size = self.model_config.get_inputs_embeds_size() # Multimodal + self.mm_registry = MULTIMODAL_REGISTRY + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + self.model_config + ) + if self.supports_mm_inputs: + self.encoder_runner = EncoderRunner( + max_num_tokens=self.max_num_tokens, + hidden_size=self.inputs_embeds_size, + dtype=self.dtype, + device=self.device, + ) self.uses_mrope = self.model_config.uses_mrope if self.uses_mrope: self.mrope_states = MRopeState( @@ -134,9 +147,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, - inputs_embeds_size=self.inputs_embeds_size, - vocab_size=self.vocab_size, - dtype=self.dtype, device=self.device, ) self.sampler = Sampler( @@ -289,6 +299,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch.mrope_positions = self.mrope_states.mrope_positions[ :, :num_tokens ] + if self.supports_mm_inputs: + input_batch.inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens] if not skip_attn: self.prepare_dummy_attn_metadata(input_batch) @@ -314,6 +326,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hidden_states = self.model( input_ids=input_batch.input_ids, positions=positions, + inputs_embeds=input_batch.inputs_embeds, ) sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states @@ -378,10 +391,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mrope_positions = None if self.uses_mrope: mrope_positions = self.mrope_states.mrope_positions + inputs_embeds = None + if self.supports_mm_inputs: + inputs_embeds = self.encoder_runner.inputs_embeds self.cudagraph_manager.capture( model=self.model, input_buffers=self.input_buffers, mrope_positions=mrope_positions, + inputs_embeds=inputs_embeds, block_tables=self.block_tables, attn_metadata_builders=self.attn_metadata_builders, kv_cache_config=self.kv_cache_config, @@ -412,8 +429,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if scheduler_output.preempted_req_ids is not None: for req_id in scheduler_output.preempted_req_ids: self.req_states.remove_request(req_id) + if self.supports_mm_inputs: + self.encoder_runner.remove_request(req_id) for req_id in scheduler_output.finished_req_ids: self.req_states.remove_request(req_id) + if self.supports_mm_inputs: + self.encoder_runner.remove_request(req_id) + + if self.supports_mm_inputs: + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_runner.free_encoder_cache(mm_hash) # Add new requests. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -432,13 +457,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) req_index = self.req_states.req_id_to_index[req_id] + if self.supports_mm_inputs: + self.encoder_runner.add_request(req_id, new_req_data.mm_features) + # Pre-compute M-RoPE positions for prefill. if self.uses_mrope: self.mrope_states.init_prefill_mrope_positions( req_index, self.model, # type: ignore new_req_data.prefill_token_ids, - mm_features=[], # TODO + mm_features=new_req_data.mm_features, ) self.block_tables.append_block_ids( @@ -632,12 +660,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_ids=input_ids, positions=positions, mrope_positions=mrope_positions, + inputs_embeds=None, attn_metadata=attn_metadata, logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, ) + @torch.inference_mode() + def get_mm_embeddings( + self, + scheduled_encoder_inputs: dict[str, list[int]], + input_batch: InputBatch, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs( + scheduled_encoder_inputs + ) + self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs) + mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings( + input_batch.req_ids, + input_batch.num_tokens, + input_batch.num_scheduled_tokens, + input_batch.query_start_loc_np, + self.req_states.prefill_len.np[input_batch.idx_mapping_np], + self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np], + ) + return mm_embeds, is_mm_embed + def sample( self, hidden_states: torch.Tensor, @@ -930,6 +979,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch.num_scheduled_tokens, ) self._set_active_loras(*lora_inputs) + + if self.supports_mm_inputs: + # Execute the multimodal encoder. + mm_embeds, is_mm_embed = self.get_mm_embeddings( + scheduler_output.scheduled_encoder_inputs, input_batch + ) + inputs_embeds = self.encoder_runner.get_inputs_embeds( + self.model, input_batch.input_ids, mm_embeds, is_mm_embed + ) + input_batch.inputs_embeds = inputs_embeds[ + : input_batch.num_tokens_after_padding + ] else: # No actual tokens to run. A dummy run for DP. num_reqs = min(num_tokens_after_padding, self.max_num_reqs) @@ -970,6 +1031,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hidden_states = self.model( input_ids=input_batch.input_ids, positions=positions, + inputs_embeds=input_batch.inputs_embeds, ) self.execute_model_state = hidden_states, input_batch diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index b4d1964f9..e8eeac7ec 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -48,9 +48,6 @@ class EagleSpeculator: self.input_buffers = InputBuffers( max_num_reqs=self.max_num_reqs, max_num_tokens=self.max_num_tokens, - inputs_embeds_size=self.inputs_embeds_size, - vocab_size=self.vocab_size, - dtype=self.dtype, device=device, ) self.hidden_states = torch.zeros(