[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Roger Wang
2025-01-06 11:58:16 -08:00
committed by GitHub
parent e20c92bb61
commit 91b361ae89
17 changed files with 633 additions and 279 deletions

View File

@@ -1,15 +1,15 @@
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_utils import BlockHashType
@@ -18,14 +18,17 @@ class Request:
def __init__(
self,
request_id: str,
inputs: DecoderOnlyInputs,
prompt: Optional[str],
prompt_token_ids: List[int],
multi_modal_inputs: Optional[List["MultiModalKwargs"]],
multi_modal_hashes: Optional[List[str]],
multi_modal_placeholders: Optional[List["PlaceholderRange"]],
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
@@ -41,26 +44,21 @@ class Request:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = self.inputs.prompt
self.prompt_token_ids = self.inputs.prompt_token_ids
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
else:
self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
# Multi-modal related
self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or []
self.mm_hashes: List[str] = multi_modal_hashes or []
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
assert len(self.mm_inputs) == len(self.mm_hashes)
# Cache the computed kv block hashes of the request to avoid
# recomputing.
@@ -70,15 +68,11 @@ class Request:
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
request_id=request.request_id,
inputs=token_inputs(
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
multi_modal_data=None,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,