[Model Runner V2] Support VLM (#32546)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-18 16:58:51 -08:00
committed by GitHub
parent 6101a26dc9
commit bb1848cd62
6 changed files with 263 additions and 15 deletions

View File

@@ -76,6 +76,7 @@ class CudaGraphManager:
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
@@ -86,6 +87,8 @@ class CudaGraphManager:
if self.uses_mrope:
assert mrope_positions is not None
positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens]
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,
@@ -108,6 +111,7 @@ class CudaGraphManager:
hidden_states = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
@@ -128,6 +132,7 @@ class CudaGraphManager:
hidden_states = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph
@@ -138,6 +143,7 @@ class CudaGraphManager:
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
@@ -149,6 +155,7 @@ class CudaGraphManager:
model=model,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,

View File

@@ -15,9 +15,6 @@ class InputBuffers:
self,
max_num_reqs: int,
max_num_tokens: int,
inputs_embeds_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
@@ -64,6 +61,8 @@ class InputBatch:
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
@@ -132,6 +131,7 @@ class InputBatch:
input_ids=input_ids,
positions=positions,
mrope_positions=None,
inputs_embeds=None,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,

View File

@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
class EncoderRunner:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens,
hidden_size,
dtype=dtype,
device=device,
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features
def free_encoder_cache(self, mm_hash: str) -> None:
self.encoder_cache.pop(mm_hash, None)
def remove_request(self, req_id: str) -> None:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self,
scheduled_encoder_inputs: dict[str, list[int]],
) -> tuple[list[str], list[MultiModalKwargsItem]]:
mm_hashes: list[str] = []
mm_kwargs: list[MultiModalKwargsItem] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
mm_features = self.req_id_to_mm_features[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hashes.append(mm_feature.identifier)
mm_kwargs.append(mm_feature.data)
return mm_hashes, mm_kwargs
@torch.inference_mode()
def execute_mm_encoder(
self,
model: SupportsMultiModal,
mm_hashes: list[str],
mm_kwargs: list[MultiModalKwargsItem],
) -> list[torch.Tensor]:
if not mm_hashes:
return []
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=False,
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=num_items,
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
for mm_hash, output in zip(mm_hashes, encoder_outputs):
self.encoder_cache[mm_hash] = output
return encoder_outputs
def gather_mm_embeddings(
self,
req_ids: list[str],
total_num_scheduled_tokens: int,
num_scheduled_tokens: np.ndarray,
query_start_loc: np.ndarray,
prefill_lens: np.ndarray,
computed_prefill_lens: np.ndarray,
) -> tuple[list[torch.Tensor], torch.Tensor]:
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
all_decode = not any(is_prefilling)
if all_decode:
# All decode requests, so no need to gather any embeddings.
return [], torch.zeros(
total_num_scheduled_tokens,
dtype=torch.bool,
device=self.device,
)
query_start = computed_prefill_lens.tolist()
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros(
total_num_scheduled_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=False,
)
for i, req_id in enumerate(req_ids):
if not is_prefilling[i]:
# OPTIMIZATION: Skip decode requests.
continue
mm_features = self.req_id_to_mm_features[req_id]
for mm_feature in mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
if start_pos >= query_end[i]:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= query_start[i]:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(query_start[i] - start_pos, 0)
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed)
return mm_embeds, is_mm_embed
@torch.inference_mode()
def get_inputs_embeds(
self,
model: SupportsMultiModal,
input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor],
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
# Copy to the pre-allocated buffer for CUDA graphs.
self.inputs_embeds[: x.shape[0]] = x
return self.inputs_embeds

View File

@@ -23,7 +23,7 @@ class MRopeState:
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
self.prefill_mrope_positions = StagedWriteTensor(
(max_num_reqs, 3 * max_model_len),
(max_num_reqs * 3, max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
@@ -58,9 +58,7 @@ class MRopeState:
)
for i in range(3):
pos = prefill_mrope_positions[i].tolist()
self.prefill_mrope_positions.stage_write(
req_idx, i * self.max_model_len, pos
)
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
def apply_staged_writes(self) -> None:
@@ -79,7 +77,7 @@ class MRopeState:
self.mrope_positions,
self.mrope_positions.stride(0),
self.prefill_mrope_positions.gpu,
self.prefill_mrope_positions.gpu.stride(0),
3 * self.max_model_len,
self.max_model_len,
self.prefill_mrope_delta.gpu,
idx_mapping,

View File

@@ -14,6 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -47,6 +48,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens,
prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
@@ -95,6 +97,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
if self.supports_mm_inputs:
self.encoder_runner = EncoderRunner(
max_num_tokens=self.max_num_tokens,
hidden_size=self.inputs_embeds_size,
dtype=self.dtype,
device=self.device,
)
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_states = MRopeState(
@@ -134,9 +147,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=self.device,
)
self.sampler = Sampler(
@@ -289,6 +299,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.mrope_positions = self.mrope_states.mrope_positions[
:, :num_tokens
]
if self.supports_mm_inputs:
input_batch.inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
if not skip_attn:
self.prepare_dummy_attn_metadata(input_batch)
@@ -314,6 +326,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=positions,
inputs_embeds=input_batch.inputs_embeds,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -378,10 +391,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
inputs_embeds = None
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture(
model=self.model,
input_buffers=self.input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
@@ -412,8 +429,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if scheduler_output.preempted_req_ids is not None:
for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id)
if self.supports_mm_inputs:
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash)
# Add new requests.
for new_req_data in scheduler_output.scheduled_new_reqs:
@@ -432,13 +457,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
req_index = self.req_states.req_id_to_index[req_id]
if self.supports_mm_inputs:
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
# Pre-compute M-RoPE positions for prefill.
if self.uses_mrope:
self.mrope_states.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=[], # TODO
mm_features=new_req_data.mm_features,
)
self.block_tables.append_block_ids(
@@ -632,12 +660,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
)
@torch.inference_mode()
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
) -> tuple[list[torch.Tensor], torch.Tensor]:
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
scheduled_encoder_inputs
)
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
)
return mm_embeds, is_mm_embed
def sample(
self,
hidden_states: torch.Tensor,
@@ -930,6 +979,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
if self.supports_mm_inputs:
# Execute the multimodal encoder.
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
inputs_embeds = self.encoder_runner.get_inputs_embeds(
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
)
input_batch.inputs_embeds = inputs_embeds[
: input_batch.num_tokens_after_padding
]
else:
# No actual tokens to run. A dummy run for DP.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
@@ -970,6 +1031,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=positions,
inputs_embeds=input_batch.inputs_embeds,
)
self.execute_model_state = hidden_states, input_batch

View File

@@ -48,9 +48,6 @@ class EagleSpeculator:
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,
)
self.hidden_states = torch.zeros(