[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (#23779)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
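This change collapses the three parallel per-request lists (multi_modal_kwargs, multi_modal_hashes, multi_modal_placeholders) into a single list of MultiModalFeatureSpec objects, so each multimodal input travels with its processed data, cache identifier, and placeholder position as one unit. As a rough sketch of the spec's shape, inferred from the field accesses visible in this diff (f.data, f.identifier, f.mm_position) plus an assumed modality tag; the authoritative definition lives in vllm/multimodal/inputs.py:

from dataclasses import dataclass
from typing import Optional

from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange

@dataclass
class MultiModalFeatureSpec:
    data: Optional[MultiModalKwargsItem]  # processed tensors for one mm item
    modality: str                         # e.g. "image" or "audio" (assumed field)
    identifier: str                       # content hash used for caching
    mm_position: PlaceholderRange         # placeholder token span in the prompt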
@@ -6,10 +6,9 @@ import time
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -26,14 +25,12 @@ class Request:
         self,
         request_id: str,
         prompt_token_ids: list[int],
-        multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
-        multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: Optional[SamplingParams],
         pooling_params: Optional[PoolingParams],
         eos_token_id: Optional[int],
         client_index: int = 0,
         arrival_time: Optional[float] = None,
+        mm_features: Optional[list[MultiModalFeatureSpec]] = None,
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
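For illustration, with the new signature a caller builds per-input specs up front instead of passing three aligned lists. A minimal sketch, assuming Request lives in vllm.v1.request and PlaceholderRange takes an offset and a length; all values are made up:

from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request

image_item = None  # stand-in for a real MultiModalKwargsItem from the mm processor

request = Request(
    request_id="req-0",
    prompt_token_ids=[1, 2, 3],
    sampling_params=SamplingParams(),
    pooling_params=None,
    eos_token_id=2,
    mm_features=[
        MultiModalFeatureSpec(
            data=image_item,
            modality="image",
            identifier="4f2a...",  # content hash of the image (made up)
            mm_position=PlaceholderRange(offset=5, length=576),
        ),
    ],
)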
@@ -89,16 +86,14 @@ class Request:
         self.cache_salt: Optional[str] = cache_salt
 
         # Multi-modal related
-        self.mm_positions = multi_modal_placeholders or []
-        self.mm_kwargs = multi_modal_kwargs or []
-        self.mm_hashes: list[str] = multi_modal_hashes or []
-        self.num_encoder_inputs = len(self.mm_kwargs)
+        self.mm_features = mm_features or []
+        self.num_encoder_inputs = len(self.mm_features)
         self.has_encoder_inputs = self.num_encoder_inputs > 0
-
-        # Sanity check
-        assert len(self.mm_kwargs) == len(self.mm_positions)
-        if self.mm_hashes:
-            assert len(self.mm_kwargs) == len(self.mm_hashes)
+        # TODO(sfeng33): Remove these legacy fields after clearing out all
+        # references in scheduler and model runner
+        self.mm_positions = [f.mm_position for f in self.mm_features]
+        self.mm_kwargs = [f.data for f in self.mm_features]
+        self.mm_hashes = [f.identifier for f in self.mm_features]
 
         # Read-only views
         # Prevent directly appending to these lists since
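One effect worth noting: the old sanity asserts disappear because the three legacy fields are now projections of the same list, so the length invariant they enforced holds by construction:

# All three views are derived from self.mm_features, so they always agree:
assert len(self.mm_positions) == len(self.mm_kwargs) == len(self.mm_hashes)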
@@ -126,20 +121,11 @@ class Request:
         cls, request: EngineCoreRequest,
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
-        if request.mm_kwargs is not None:
-            mm_kwargs_lst = list(request.mm_kwargs)
-            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
-                "mm_kwargs was not updated in EngineCore.add_request")
-        else:
-            mm_kwargs_lst = None
-
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=mm_kwargs_lst,
-            multi_modal_hashes=request.mm_hashes,
-            multi_modal_placeholders=request.mm_placeholders,
+            mm_features=request.mm_features,
             sampling_params=request.sampling_params,
             pooling_params=request.pooling_params,
             eos_token_id=request.eos_token_id,
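Downstream, from_engine_core_request now just forwards mm_features, since EngineCore attaches fully processed features before the request gets here; the old is_list_of check (and its import) becomes unnecessary. A hypothetical consumer, where schedule_encoder_input is a made-up helper standing in for the scheduler's real bookkeeping:

request = Request.from_engine_core_request(engine_core_request, block_hasher=None)
for f in request.mm_features:
    # Each spec carries the cache key and the token span it fills.
    schedule_encoder_input(f.identifier, f.mm_position)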