[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,5 +1,4 @@
 from functools import cached_property
-from types import MethodType
 from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)
 
@@ -7,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
-                          ProcessorMixin, SiglipVisionConfig)
+                          SiglipVisionConfig)
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
 
@@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        ProcessorInputs, PromptReplacement,
+                                        full_groupby_modality)
 from vllm.sequence import IntermediateTensors
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
 
 class LlavaMultiModalProcessor(BaseMultiModalProcessor):
 
-    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
-        if getattr(hf_processor, "__is_patched__", False):
-            return  # Already patched
-
-        image_processor = hf_processor.image_processor  # type: ignore
-        orig_preprocess = image_processor.preprocess
-
-        def preprocess(__self, *args, **kwargs):
-            hf_inputs = orig_preprocess(*args, **kwargs)
-            hf_inputs["is_pixtral"] = torch.tensor(True)
-            return hf_inputs
-
-        image_processor.preprocess = MethodType(preprocess, image_processor)
-
-        hf_processor.__is_patched__ = True  # type: ignore
-
     def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-        hf_processor = self.ctx.get_hf_processor(
-            (LlavaProcessor, PixtralProcessor))
-
-        if isinstance(hf_processor, PixtralProcessor):
-            self._patch_pixtral_processor(hf_processor)
-
-        return hf_processor
+        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        # NOTE: pixel_values=None for MLlavaProcessor
+        pixel_values = processed_outputs.get("pixel_values")
+        if pixel_values is not None:
+            images = mm_data["images"]
+            assert isinstance(images, list)
+
+            if isinstance(self._get_hf_processor(), PixtralProcessor):
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1
+                        and isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
+
+                processed_outputs["pixel_values"] = pixel_values[0]
+
        return processed_outputs
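
Note on the Pixtral branch above: PixtralImageProcessor returns pixel_values
nested one level deeper than vLLM expects, so the override unwraps the
singleton batch level. A minimal standalone sketch with toy shapes (not the
real processor output):

    import torch

    # Hypothetical Pixtral-style output: (1, num_images, C, H, W) as nested
    # lists of per-image tensors, which may differ in size.
    pixel_values = [[torch.rand(3, 64, 48), torch.rand(3, 32, 96)]]
    images = ["img0", "img1"]  # stand-ins for the original PIL images

    assert (isinstance(pixel_values, list) and len(pixel_values) == 1
            and isinstance(pixel_values[0], list)
            and len(pixel_values[0]) == len(images))

    # Drop the singleton batch level: (num_images, C, H, W)
    pixel_values = pixel_values[0]
    assert pixel_values[0].shape == (3, 64, 48)
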
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
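
Note on _get_mm_fields_config: the field configs tell the base processor
which slice of each HF output belongs to which multi-modal item, which is
what lets per-item processor outputs be cached and re-merged later;
batched("image") means index i along the leading dimension belongs to
image i. A rough sketch of that slicing rule (hypothetical helper, not
vLLM's actual implementation):

    import torch

    def split_batched(t: torch.Tensor) -> list[torch.Tensor]:
        # "batched" fields: element i of dim 0 belongs to item i, so each
        # item's slice can be cached and re-assembled independently.
        return [t[i] for i in range(t.shape[0])]

    pixel_values = torch.rand(2, 3, 336, 336)  # two images
    assert split_batched(pixel_values)[0].shape == (3, 336, 336)
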
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         image_token_id = hf_config.image_token_index
@@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
     ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         vision_config = hf_config.vision_config
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
 
         if isinstance(vision_config, CLIPVisionConfig):
             data = dummy_image_for_clip(vision_config, num_images)
@@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
 
@@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            assert isinstance(is_pixtral, torch.Tensor)
-            if is_pixtral.any():
-                images = pixel_values
-
-                def flatten_to_3d_tensors(item):
-                    if isinstance(item, torch.Tensor):
-                        if item.dim() >= 3:
-                            return [t for t in item.view(-1, *item.shape[-3:])]
-                        else:
-                            raise ValueError(
-                                f"Unexpected tensor dimension: {item.dim()}")
-                    elif isinstance(item, list):
-                        return [
-                            t for subitem in item
-                            for t in flatten_to_3d_tensors(subitem)
-                        ]
-                    else:
-                        raise ValueError(f"Unexpected type: {type(item)}")
-
-                # Restructure the batched images into a list of lists of images
-                images = flatten_to_3d_tensors(pixel_values)
-
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=images,
-                )
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
@@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
-    def _get_hf_processor(self) -> ProcessorMixin:
-        try:
-            from mantis.models.mllava import MLlavaProcessor
-        except ModuleNotFoundError as exc:
-            raise ModuleNotFoundError(
-                "You need to `pip install "
-                "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
-                "to use this model") from exc
-
-        processor = MLlavaProcessor.from_pretrained(
-            self.ctx.model_config.tokenizer)
-        assert isinstance(processor, ProcessorMixin)
-        return processor
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(LlavaProcessor)
+
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        image_token_id = hf_config.image_token_index
+        max_image_tokens = get_max_llava_image_tokens(self.ctx)
+
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        mm_items = self._get_mm_items(mm_data)
+        mm_item_counts = mm_items.get_item_counts()
+        mm_kwargs = result["mm_kwargs"]
+
+        # We reimplement the functionality of MLlavaProcessor from
+        # https://github.com/TIGER-AI-Lab/Mantis.git
+        def get_replacement_mantis(item_idx: int):
+            return "".join([
+                f"(image {item_idx+1}: <Image>",  # 7 tokens
+                "<image>" * max_image_tokens,
+                "</Image>)",  # 3 tokens
+            ])
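
For a hypothetical max_image_tokens of 4, get_replacement_mantis(0) builds
the same wrapper string that Mantis' own MLlavaProcessor would emit:

    max_image_tokens = 4  # illustrative; really comes from get_max_llava_image_tokens

    def get_replacement_mantis(item_idx: int) -> str:
        return "".join([
            f"(image {item_idx+1}: <Image>",
            "<image>" * max_image_tokens,
            "</Image>)",
        ])

    assert get_replacement_mantis(0) == \
        "(image 1: <Image><image><image><image><image></Image>)"
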
+        mantis_repls = self._bind_prompt_replacements([
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id] * max_image_tokens,
+                replacement=get_replacement_mantis,
+            )
+        ])
+
+        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
+            result["prompt_token_ids"],
+            mantis_repls,
+            mm_item_counts,
+        )
+
+        unbound_orig_repls = self._get_prompt_replacements(
+            mm_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
+
+        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
+                                                   mm_item_counts)
+        assert len(all_placeholders) == mm_item_counts.get("image", 0)
+
+        mm_placeholders = {
+            modality: [item.to_range() for item in items]
+            for modality, items in full_groupby_modality(all_placeholders)
+        }
+
+        return MultiModalInputsV2(
+            type="multimodal",
+            prompt=prompt_text,
+            prompt_token_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            mm_placeholders=mm_placeholders,
+        )
 
 
 # To use this model, please use
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors,
+                                    PlaceholderRange)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        ProcessorInputs, PromptReplacement,
+                                        _BoundPromptReplacement,
+                                        _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
@@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
     *,
     num_crops: Optional[int] = None,
 ) -> int:
-    mm_processor_kwargs = {}
+    hf_processor_mm_kwargs = {}
     if num_crops:
-        mm_processor_kwargs["num_crops"] = num_crops
+        hf_processor_mm_kwargs["num_crops"] = num_crops
 
-    processor = ctx.get_hf_processor(**mm_processor_kwargs)
+    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
 
     return processor.calc_num_image_tokens_from_image_size(
         width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
-            hf_processor,
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
+        input_ids = processed_outputs["input_ids"]
+        assert isinstance(input_ids, torch.Tensor)
+
         # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
         # which will cause OverflowError when decoding the prompt_ids.
         # Therefore, we need to do an early replacement here
-        token_ids = processed_outputs['input_ids']
-        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
-        processed_outputs['input_ids'] = token_ids
+        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
 
         return processed_outputs
 
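
The in-place masked_fill_ swaps HF's negative image sentinels for a real
token id before the ids are ever decoded. A self-contained demonstration
with made-up ids (_IMAGE_TOKEN_ID is 32044 in this file):

    import torch

    _IMAGE_TOKEN_ID = 32044

    input_ids = torch.tensor([1, 450, -1, -1, -1, 338])  # -1 marks image slots
    input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
    assert input_ids.tolist() == [1, 450, 32044, 32044, 32044, 338]
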
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
         image_processor = hf_processor.image_processor  # type: ignore
 
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
+        tokenizer = self._get_tokenizer()
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
 
         def get_replacement_phi3v(item_idx: int):
             image_size = mm_items.get_image_size(item_idx)
@@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
                 height=image_size.height,
             )
 
-            return [_IMAGE_TOKEN_ID] * num_tokens
+            return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
 
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:max_images]
+            ) for image_token in image_tokens[:len(mm_items.images)]
         ]
 
+    def _apply_prompt_replacements(
+        self,
+        token_ids: list[int],
+        prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
+    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
+        token_ids, text, placeholders = super()._apply_prompt_replacements(
+            token_ids=token_ids,
+            prompt_repls=prompt_repls,
+            mm_item_counts=mm_item_counts,
+        )
+
+        # Keep the behavior in line with HF processor
+        if text.startswith("<s> <|image|>"):
+            text = text.replace("<s> <|image|>", "<s><|image|>", 1)
+            token_ids = [token_ids[0], *token_ids[2:]]
+            placeholders = [
+                _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
+                for p in placeholders
+            ]
+
+        return token_ids, text, placeholders
+
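
Dropping the spurious space token after "<s>" removes one token from the
stream, so every placeholder that follows starts one position earlier. The
index arithmetic, with illustrative ids:

    token_ids = [1, 29871, 32044, 32044]  # ["<s>", " ", "<|image|>", ...]
    placeholder_starts = [2]              # <|image|> run begins at index 2

    token_ids = [token_ids[0], *token_ids[2:]]
    placeholder_starts = [s - 1 for s in placeholder_starts]
    assert token_ids == [1, 32044, 32044] and placeholder_starts == [1]
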
     def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
 
         data = dummy_image_for_clip(
             CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text="".join(image_tokens[:num_images]),
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <|image|> tokens should be considered as placeholders,
+        # so we ignore the trailing bos_token_id
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+                for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -225,7 +225,7 @@ class VisualAttentionBlock(nn.Module):
         d_model: int,
         n_head: int,
         mlp_ratio: float = 4.0,
-        norm_layer: Callable = nn.LayerNorm,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -266,7 +266,7 @@ class TransformerBlock(nn.Module):
         layers: int,
         heads: int,
         mlp_ratio: float = 4.0,
-        norm_layer: Callable = nn.LayerNorm,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -26,7 +26,7 @@ from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, ProcessorMixin
+from transformers import BatchFeature
 from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                              Qwen2AudioEncoder,
                                              Qwen2AudioProcessor)
@@ -38,10 +38,10 @@ from vllm.inputs import InputContext
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
+                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        ProcessorInputs, PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -73,7 +73,7 @@ class Qwen2AudioMultiModalProjector(nn.Module):
 
 
 # From Qwen2AudioEncoder._get_feat_extract_output_lengths
-def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
+def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     feat_lengths = (input_lengths - 1) // 2 + 1
     output_lengths = (feat_lengths - 2) // 2 + 1
     return feat_lengths, output_lengths
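
The two integer divisions mirror the encoder's stride-2 convolution and
stride-2 pooling. Worked example for a full 30-second Whisper chunk
(3000 mel frames):

    input_lengths = 3000
    feat_lengths = (input_lengths - 1) // 2 + 1    # -> 1500
    output_lengths = (feat_lengths - 2) // 2 + 1   # -> 750
    assert (feat_lengths, output_lengths) == (1500, 750)
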
@@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
 
 class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
 
-    def _get_hf_processor(self) -> Qwen2AudioProcessor:
+    def _get_hf_processor(
+        self,
+        *,
+        # Ignored in initialization
+        sampling_rate: Optional[int] = None,
+    ) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor)
 
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         return self._get_hf_processor().feature_extractor  # type: ignore
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -102,50 +107,61 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
         feature_extractor = self._get_feature_extractor()
         mm_items.resample_audios(feature_extractor.sampling_rate)
 
-        return super()._get_processor_data(mm_items)
+        return super()._get_hf_mm_data(mm_items)
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processor_data = dict(processor_data)
-        audios = processor_data.pop("audios", [])
+        mm_data = dict(mm_data)
+        audios = mm_data.pop("audios", [])
 
         if audios:
-            processor_data["audios"] = audios
+            mm_data["audios"] = audios
 
             feature_extractor = self._get_feature_extractor()
-            mm_processor_kwargs = dict(
-                **mm_processor_kwargs,
+            mm_kwargs = dict(
+                **mm_kwargs,
                 sampling_rate=feature_extractor.sampling_rate,
             )
         else:
             # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
             pass
 
-        return super()._call_hf_processor(
-            hf_processor,
+        processed_outputs = super()._call_hf_processor(
             prompt=prompt,
-            processor_data=processor_data,
-            mm_processor_kwargs=mm_processor_kwargs,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            input_features=MultiModalFieldConfig.batched("audio"),
+            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
         placeholder = hf_config.audio_token_index
 
-        feature_attention_mask = hf_inputs.get("feature_attention_mask")
+        feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
         if feature_attention_mask is None:
             audio_output_lengths = []
         else:
             assert isinstance(feature_attention_mask, torch.Tensor)
             _, audio_output_lengths = _get_feat_extract_output_lengths(
                 feature_attention_mask.sum(-1))
 
@@ -168,14 +184,13 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
 
-        audio_count = mm_counts["audio"]
+        audio_count = mm_counts.get("audio", 0)
         audio = np.zeros(audio_len)
         data = {"audio": [audio] * audio_count}
 
         return ProcessorInputs(
             prompt_text="<|AUDIO|>" * audio_count,
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -22,9 +22,10 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import cached_property, partial
-from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, Type, TypedDict, Union)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Set, Tuple, Type, TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -54,10 +55,11 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        ProcessorInputs, PromptReplacement)
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module):
 
     def __init__(
         self,
-        embed_dim: Optional[int] = None,
-        num_heads: Optional[int] = None,
-        projection_size: Optional[int] = None,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -264,7 +266,7 @@
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+        rotary_pos_emb: torch.Tensor,
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -347,7 +349,7 @@ class Qwen2VisionBlock(nn.Module):
         num_heads: int,
         mlp_ratio: float,
         act_layer: Type[nn.Module] = QuickGELU,
-        norm_layer: Type[nn.Module] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -384,7 +386,7 @@ class Qwen2VisionPatchEmbed(nn.Module):
         self,
         patch_size: int = 14,
         temporal_patch_size: int = 2,
-        in_chans: int = 3,
+        in_channels: int = 3,
         embed_dim: int = 1152,
     ) -> None:
         super().__init__()
@@ -392,8 +394,8 @@
         self.temporal_patch_size = temporal_patch_size
         self.embed_dim = embed_dim
 
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(in_chans,
+        kernel_size = (temporal_patch_size, patch_size, patch_size)
+        self.proj = nn.Conv3d(in_channels,
                               embed_dim,
                               kernel_size=kernel_size,
                               stride=kernel_size,
@@ -413,7 +415,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         self,
         d_model: int,
         context_dim: int,
-        norm_layer: Type[nn.Module] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -489,15 +491,15 @@ class Qwen2VisionTransformer(nn.Module):
     ) -> None:
         super().__init__()
 
-        patch_size: int = vision_config.patch_size
-        temporal_patch_size: int = vision_config.temporal_patch_size
-        spatial_merge_size: int = vision_config.spatial_merge_size
-        in_chans: int = vision_config.in_chans
-        hidden_size: int = vision_config.hidden_size
-        embed_dim: int = vision_config.embed_dim
-        depth: int = vision_config.depth
-        num_heads: int = vision_config.num_heads
-        mlp_ratio: float = vision_config.mlp_ratio
+        patch_size = vision_config.patch_size
+        temporal_patch_size = vision_config.temporal_patch_size
+        spatial_merge_size = vision_config.spatial_merge_size
+        in_channels = vision_config.in_channels
+        hidden_size = vision_config.hidden_size
+        embed_dim = vision_config.embed_dim
+        depth = vision_config.depth
+        num_heads = vision_config.num_heads
+        mlp_ratio = vision_config.mlp_ratio
 
         self.spatial_merge_size = spatial_merge_size
         self.num_heads = num_heads
@@ -506,7 +508,7 @@
         self.patch_embed = Qwen2VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=embed_dim,
         )
 
@@ -733,8 +735,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
             if k == "video":
                 # Special case since even a single item can be a list
                 multi_data[k] = (  # type: ignore[index]
-                    v if (isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
-                          or is_list_of(v, list)) else [v]
+                    v if (
+                        isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
+                        or is_list_of(v, list)
+                        or isinstance(v[0], (np.ndarray, torch.Tensor))
+                        and v[0].ndim == 4
+                    ) else [v]
                 )
             elif k in ("image", "audio"):
                 multi_data[k] = (  # type: ignore[index]
@@ -754,6 +760,12 @@
             for m, items in self.items()
         }
 
+    def has_embedding_inputs(self) -> bool:
+        return any(
+            isinstance(items, dict) or any(
+                isinstance(item, torch.Tensor) for item in items)
+            for items in self.values())
+
 
 class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
 
@@ -784,7 +796,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
 
         return hf_processor
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -805,7 +817,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
                         and v[0].ndim == 2):
                     # Pass through embedding inputs (multi)
                     passthrough_data[f"{k}_embeds"] = v
-                else:
+                elif len(v) > 0:
                     # Map keys to plural form, e.g.: image -> images
                     processor_data[f"{k}s"] = v
         else:
@@ -816,8 +828,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_processor = _get_image_processor(hf_processor)
@@ -831,7 +843,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
         merge_length = image_processor.merge_size**2
 
         def get_replacement_qwen2vl(item_idx: int, modality: str):
-            grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
+            grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
+            assert isinstance(grid_thw, torch.Tensor)
+
             num_tokens = grid_thw.prod() // merge_length
             return placeholder[modality] * num_tokens
 
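
Worked example for get_replacement_qwen2vl: with merge_size 2 (so four
patches merge into one token), a single 672x504 frame at 14 px per patch
yields grid_thw = (1, 48, 36) and therefore 432 placeholder tokens:

    import torch

    merge_length = 2**2
    grid_thw = torch.tensor([1, 48, 36])
    num_tokens = grid_thw.prod() // merge_length
    assert num_tokens.item() == 432
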
@@ -844,11 +858,40 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
             ) for modality in ("image", "video")
         ]
 
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+        image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
+        image_slices = [
+            slice(image_slice_idxs[i], image_slice_idxs[i + 1])
+            for i in range(len(image_grid_thw))
+        ]
+
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
+        video_slices = [
+            slice(video_slice_idxs[i], video_slice_idxs[i + 1])
+            for i in range(len(video_grid_thw))
+        ]
+
+        return dict(
+            pixel_values=MultiModalFieldConfig.flat("image", image_slices),
+            image_embeds=MultiModalFieldConfig.flat("image", image_slices),
+            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.flat(
+                "video", video_slices),
+            video_embeds=MultiModalFieldConfig.flat("video", video_slices),
+            video_grid_thw=MultiModalFieldConfig.batched("video"),
+        )
+
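
Unlike batched fields, flat fields pack all items along one dimension, so
per-item boundaries must be spelled out: the cumulative product of each
(t, h, w) grid gives the row count contributed by each item. A standalone
sketch of the slice computation with two toy grids:

    import torch

    image_grid_thw = torch.tensor([[1, 6, 6], [1, 4, 4]])  # 36 + 16 rows
    idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()  # [0, 36, 52]
    slices = [slice(idxs[i], idxs[i + 1]) for i in range(len(image_grid_thw))]
    assert slices == [slice(0, 36), slice(36, 52)]
    # pixel_values rows 0:36 belong to image 0, rows 36:52 to image 1.
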
     def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
         hf_processor = self._get_hf_processor()
         image_token: str = hf_processor.image_token
         image_processor = _get_image_processor(hf_processor)
@@ -869,7 +912,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
 
@@ -950,9 +992,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
         return quant_config
 
-    def _validate_and_reshape_mm_tensor(self,
-                                        mm_input: Union[torch.Tensor,
-                                                        List[torch.Tensor]],
+    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
             raise ValueError(f"Incorrect type of {name}. "
@@ -962,7 +1002,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                 return mm_input
             if mm_input.ndim != 3:
                 raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim}")
+                                 f"Got ndim: {mm_input.ndim} "
+                                 f"(shape={mm_input.shape})")
             return torch.concat(list(mm_input))
         else:
             return torch.concat(mm_input)
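
_validate_and_reshape_mm_tensor accepts either a pre-batched 3D tensor or a
list of (possibly ragged) 2D tensors and flattens both to one 2D tensor:

    import torch

    batched = torch.rand(2, 4, 8)            # (batch, rows, dim)
    assert torch.concat(list(batched)).shape == (8, 8)

    ragged = [torch.rand(3, 8), torch.rand(5, 8)]
    assert torch.concat(ragged).shape == (8, 8)
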
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -23,10 +23,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+                                        ProcessorInputs, PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of
@@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
 
 class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
 
+    def _get_hf_processor(
+        self,
+        *,
+        # Ignored in initialization
+        sampling_rate: Optional[int] = None,
+    ) -> ProcessorMixin:
+        return self.ctx.get_hf_processor()
+
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         hf_processor = self._get_hf_processor()
         return hf_processor.audio_processor.feature_extractor  # type: ignore
 
-    def _get_processor_data(
+    def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -84,33 +93,41 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         feature_extractor = self._get_feature_extractor()
         mm_items.resample_audios(feature_extractor.sampling_rate)
 
-        return super()._get_processor_data(mm_items)
+        return super()._get_hf_mm_data(mm_items)
 
     def _call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
         prompt: str,
-        processor_data: Mapping[str, object],
-        mm_processor_kwargs: Mapping[str, object],
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processor_data = dict(processor_data)
-        audios = processor_data.pop("audios", [])
+        # Text-only input not supported in composite processor
+        if not mm_data:
+            tokenizer = self._get_tokenizer()
+
+            prompt_ids = tokenizer.encode(
+                prompt,
+                add_special_tokens=False,  # type: ignore
+            )
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        mm_data = dict(mm_data)
+        audios = mm_data.pop("audios", [])
 
         if not audios:
             return super()._call_hf_processor(
-                hf_processor,
                 prompt=prompt,
-                processor_data=processor_data,
-                mm_processor_kwargs=mm_processor_kwargs,
+                mm_data=mm_data,
+                mm_kwargs=mm_kwargs,
             )
 
         feature_extractor = self._get_feature_extractor()
-        mm_processor_kwargs = dict(
-            **mm_processor_kwargs,
+        mm_kwargs = dict(
+            **mm_kwargs,
             sampling_rate=feature_extractor.sampling_rate,
         )
 
-        # Already resampled by _get_processor_data
+        # Already resampled by _get_hf_mm_data
         assert is_list_of(audios, np.ndarray)
 
         # Ultravox processor doesn't support multiple inputs,
@@ -119,13 +136,12 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         shared_outputs = {}
         for audio in audios:
             # NOTE: Ultravox processor accepts "audio" instead of "audios"
-            item_processor_data = dict(**processor_data, audio=audio)
+            item_processor_data = dict(**mm_data, audio=audio)
 
             item_outputs = super()._call_hf_processor(
-                hf_processor,
                 prompt=prompt,
-                processor_data=item_processor_data,
-                mm_processor_kwargs=mm_processor_kwargs,
+                mm_data=item_processor_data,
+                mm_kwargs=mm_kwargs,
             )
 
             audio_features.append(item_outputs.pop("audio_values")[0])
@@ -139,17 +155,28 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         )
         return BatchFeature(combined_outputs)
 
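
Because the Ultravox processor takes one audio per call, the loop above runs
it once per clip and re-batches the outputs by hand. A rough sketch of that
re-batching step, with hypothetical shapes:

    import numpy as np

    # Pretend per-clip outputs from two processor calls.
    audio_features = [np.zeros((80, 100)), np.zeros((80, 60))]  # ragged
    audio_token_len = [13, 8]

    combined_outputs = dict(
        audio_features=audio_features,    # kept as a list, one entry per clip
        audio_token_len=audio_token_len,  # consumed by get_replacement_ultravox
    )
    assert len(combined_outputs["audio_features"]) == 2
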
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            audio_features=MultiModalFieldConfig.batched("audio"),
+            audio_token_len=MultiModalFieldConfig.batched("audio"),
+            audio_embeds=MultiModalFieldConfig.batched("audio"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         placeholder = hf_processor.audio_token_replacement  # type: ignore
 
         def get_replacement_ultravox(item_idx: int):
-            audio_token_len = hf_inputs["audio_token_len"][item_idx]
+            audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
             return placeholder * audio_token_len
 
         return [
@@ -168,14 +195,13 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
 
-        audio_count = mm_counts["audio"]
+        audio_count = mm_counts.get("audio", 0)
         audio = np.zeros(audio_len)
         data = {"audio": [audio] * audio_count}
 
         return ProcessorInputs(
             prompt_text="<|audio|>" * audio_count,
             mm_data=data,
-            mm_processor_kwargs={},
         )