[Bugfix] Make MM batching more robust (#33817)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -62,7 +62,6 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
|
||||
from .interfaces_base import attn_type
|
||||
@@ -74,7 +73,11 @@ def _terratorch_field_names(input_definition: InputDefinition):
|
||||
return set(input_definition.data.keys())
|
||||
|
||||
|
||||
def _terratorch_field_factory(input_definition: InputDefinition):
|
||||
def _terratorch_field_factory(
|
||||
input_definition: InputDefinition,
|
||||
*,
|
||||
is_shared: bool = True, # True for unprocessed data, False for processed data
|
||||
):
|
||||
def _terratorch_field_config(
|
||||
hf_inputs: Mapping[str, torch.Tensor],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
@@ -82,7 +85,11 @@ def _terratorch_field_factory(input_definition: InputDefinition):
|
||||
for name, input in input_definition.data.items():
|
||||
modality = "image"
|
||||
if input.type == InputTypeEnum.tensor:
|
||||
fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
|
||||
fields[name] = (
|
||||
MultiModalFieldConfig.shared(modality, batch_size=1)
|
||||
if is_shared
|
||||
else MultiModalFieldConfig.batched(modality)
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
@@ -166,8 +173,14 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
*,
|
||||
is_shared: bool = True,
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _terratorch_field_factory(self.info.input_definition)(hf_inputs)
|
||||
factory = _terratorch_field_factory(
|
||||
self.info.input_definition,
|
||||
is_shared=is_shared,
|
||||
)
|
||||
return factory(hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -193,12 +206,19 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
|
||||
)
|
||||
|
||||
_, passthrough_data = self._get_hf_mm_data(mm_items)
|
||||
mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt")
|
||||
mm_processed_data = BatchFeature(
|
||||
{k: torch.tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
|
||||
tensor_type="pt",
|
||||
)
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_processed_data,
|
||||
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
|
||||
self._get_mm_fields_config(
|
||||
mm_processed_data,
|
||||
hf_processor_mm_kwargs,
|
||||
is_shared=False,
|
||||
),
|
||||
)
|
||||
|
||||
return MultiModalInputs(
|
||||
@@ -235,9 +255,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
self.inference_runner = InferenceRunner(config)
|
||||
self.model = self.inference_runner.model
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = IdentityPooler()
|
||||
|
||||
def embed_input_ids(
|
||||
@@ -262,15 +279,8 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
|
||||
|
||||
batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
|
||||
model_output = self.inference_runner.forward(**batched_kwargs).output
|
||||
|
||||
# The leading dimension of hidden states needs to equal input length
|
||||
return model_output.expand(
|
||||
input_len, *(-1 for _ in range(model_output.ndim - 1))
|
||||
)
|
||||
model_output = self.inference_runner.forward(**kwargs)
|
||||
return model_output.output
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_list = []
|
||||
|
||||
Reference in New Issue
Block a user