[Model] Pass mm_features directly into get_mrope_input_positions (#28399)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -27,6 +27,7 @@ from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.multimodal import MultiModalKwargsItems
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalInputs,
|
||||
MultiModalUUIDDict,
|
||||
@@ -38,7 +39,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -367,20 +368,34 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
hf_config: "PretrainedConfig",
|
||||
image_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
mm_features: list[MultiModalFeatureSpec],
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)):
|
||||
kwargs = MultiModalFeatureSpec.gather_kwargs(
|
||||
mm_features,
|
||||
{
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"second_per_grid_ts",
|
||||
"audio_feature_lengths",
|
||||
"use_audio_in_video",
|
||||
},
|
||||
)
|
||||
if any(
|
||||
v
|
||||
for k, v in kwargs.items()
|
||||
if k not in {"image_grid_thw", "video_grid_thw"}
|
||||
):
|
||||
raise NotImplementedError("Transformers backend only supports images.")
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
image_grid_thw = kwargs.get("image_grid_thw", [])
|
||||
video_grid_thw = kwargs.get("video_grid_thw", [])
|
||||
|
||||
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
|
||||
image_grid_thw
|
||||
)
|
||||
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
|
||||
video_grid_thw
|
||||
)
|
||||
|
||||
mrope_positions, mrope_position_delta = self.model.get_rope_index(
|
||||
input_ids=torch.tensor(input_tokens).unsqueeze(0),
|
||||
|
||||
Reference in New Issue
Block a user