[Model Runner V2] Support VLM (#32546)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -76,6 +76,7 @@ class CudaGraphManager:
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -86,6 +87,8 @@ class CudaGraphManager:
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
attn_metadata = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
@@ -108,6 +111,7 @@ class CudaGraphManager:
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
@@ -128,6 +132,7 @@ class CudaGraphManager:
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
self.graphs[num_tokens] = graph
|
||||
@@ -138,6 +143,7 @@ class CudaGraphManager:
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -149,6 +155,7 @@ class CudaGraphManager:
|
||||
model=model,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
kv_cache_config=kv_cache_config,
|
||||
|
||||
@@ -15,9 +15,6 @@ class InputBuffers:
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
inputs_embeds_size: int,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
@@ -64,6 +61,8 @@ class InputBatch:
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor | None
|
||||
# [num_tokens_after_padding, hidden_size]
|
||||
inputs_embeds: torch.Tensor | None
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
@@ -132,6 +131,7 @@ class InputBatch:
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=None,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
|
||||
184
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
184
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
|
||||
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
|
||||
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)
|
||||
|
||||
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
|
||||
self.req_id_to_mm_features[req_id] = mm_features
|
||||
|
||||
def free_encoder_cache(self, mm_hash: str) -> None:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
) -> tuple[list[str], list[MultiModalKwargsItem]]:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[MultiModalKwargsItem] = []
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_feature = mm_features[mm_input_id]
|
||||
if mm_feature.data is None:
|
||||
continue
|
||||
mm_hashes.append(mm_feature.identifier)
|
||||
mm_kwargs.append(mm_feature.data)
|
||||
return mm_hashes, mm_kwargs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=False,
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs,
|
||||
expected_num_items=num_items,
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for mm_hash, output in zip(mm_hashes, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = output
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
total_num_scheduled_tokens: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
query_start_loc: np.ndarray,
|
||||
prefill_lens: np.ndarray,
|
||||
computed_prefill_lens: np.ndarray,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
|
||||
all_decode = not any(is_prefilling)
|
||||
if all_decode:
|
||||
# All decode requests, so no need to gather any embeddings.
|
||||
return [], torch.zeros(
|
||||
total_num_scheduled_tokens,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
query_start = computed_prefill_lens.tolist()
|
||||
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
is_mm_embed = torch.zeros(
|
||||
total_num_scheduled_tokens,
|
||||
dtype=torch.bool,
|
||||
device="cpu",
|
||||
pin_memory=False,
|
||||
)
|
||||
for i, req_id in enumerate(req_ids):
|
||||
if not is_prefilling[i]:
|
||||
# OPTIMIZATION: Skip decode requests.
|
||||
continue
|
||||
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_feature in mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
if start_pos >= query_end[i]:
|
||||
# The encoder output is not needed in this step.
|
||||
break
|
||||
if start_pos + num_encoder_tokens <= query_start[i]:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
start_idx = max(query_start[i] - start_pos, 0)
|
||||
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
|
||||
# Copy the is_mm_embed tensor to the GPU.
|
||||
is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed)
|
||||
return mm_embeds, is_mm_embed
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
self.inputs_embeds[: x.shape[0]] = x
|
||||
return self.inputs_embeds
|
||||
@@ -23,7 +23,7 @@ class MRopeState:
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# wasting a lot of CPU memory.
|
||||
self.prefill_mrope_positions = StagedWriteTensor(
|
||||
(max_num_reqs, 3 * max_model_len),
|
||||
(max_num_reqs * 3, max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
@@ -58,9 +58,7 @@ class MRopeState:
|
||||
)
|
||||
for i in range(3):
|
||||
pos = prefill_mrope_positions[i].tolist()
|
||||
self.prefill_mrope_positions.stage_write(
|
||||
req_idx, i * self.max_model_len, pos
|
||||
)
|
||||
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
|
||||
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
@@ -79,7 +77,7 @@ class MRopeState:
|
||||
self.mrope_positions,
|
||||
self.mrope_positions.stride(0),
|
||||
self.prefill_mrope_positions.gpu,
|
||||
self.prefill_mrope_positions.gpu.stride(0),
|
||||
3 * self.max_model_len,
|
||||
self.max_model_len,
|
||||
self.prefill_mrope_delta.gpu,
|
||||
idx_mapping,
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
@@ -47,6 +48,7 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
@@ -95,6 +97,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
|
||||
|
||||
# Multimodal
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config
|
||||
)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner = EncoderRunner(
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.inputs_embeds_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
self.mrope_states = MRopeState(
|
||||
@@ -134,9 +147,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
inputs_embeds_size=self.inputs_embeds_size,
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.sampler = Sampler(
|
||||
@@ -289,6 +299,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens
|
||||
]
|
||||
if self.supports_mm_inputs:
|
||||
input_batch.inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
|
||||
if not skip_attn:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
|
||||
@@ -314,6 +326,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=input_batch.inputs_embeds,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -378,10 +391,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs:
|
||||
inputs_embeds = self.encoder_runner.inputs_embeds
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
input_buffers=self.input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=self.block_tables,
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -412,8 +429,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if scheduler_output.preempted_req_ids is not None:
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.remove_request(req_id)
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.remove_request(req_id)
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||
self.encoder_runner.free_encoder_cache(mm_hash)
|
||||
|
||||
# Add new requests.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
@@ -432,13 +457,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
|
||||
|
||||
# Pre-compute M-RoPE positions for prefill.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.init_prefill_mrope_positions(
|
||||
req_index,
|
||||
self.model, # type: ignore
|
||||
new_req_data.prefill_token_ids,
|
||||
mm_features=[], # TODO
|
||||
mm_features=new_req_data.mm_features,
|
||||
)
|
||||
|
||||
self.block_tables.append_block_ids(
|
||||
@@ -632,12 +660,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_mm_embeddings(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
input_batch: InputBatch,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
|
||||
scheduled_encoder_inputs
|
||||
)
|
||||
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
|
||||
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
|
||||
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
|
||||
)
|
||||
return mm_embeds, is_mm_embed
|
||||
|
||||
def sample(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -930,6 +979,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self._set_active_loras(*lora_inputs)
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
# Execute the multimodal encoder.
|
||||
mm_embeds, is_mm_embed = self.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs, input_batch
|
||||
)
|
||||
inputs_embeds = self.encoder_runner.get_inputs_embeds(
|
||||
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
|
||||
)
|
||||
input_batch.inputs_embeds = inputs_embeds[
|
||||
: input_batch.num_tokens_after_padding
|
||||
]
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
@@ -970,6 +1031,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=input_batch.inputs_embeds,
|
||||
)
|
||||
|
||||
self.execute_model_state = hidden_states, input_batch
|
||||
|
||||
@@ -48,9 +48,6 @@ class EagleSpeculator:
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
inputs_embeds_size=self.inputs_embeds_size,
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
|
||||
Reference in New Issue
Block a user