[Bugfix] Make MM batching more robust (#33817)

Author: Cyrus Leung
Date: 2026-02-06 04:40:58 +08:00
Committed by: GitHub
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 4145e50d85
Commit: 116880a5a0

13 changed files with 625 additions and 428 deletions

@@ -62,7 +62,6 @@ from vllm.multimodal.processing import (
     PromptUpdate,
 )
 from vllm.sequence import IntermediateTensors
-from vllm.utils import length_from_prompt_token_ids_or_embeds
 
 from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
 from .interfaces_base import attn_type
@@ -74,7 +73,11 @@ def _terratorch_field_names(input_definition: InputDefinition):
     return set(input_definition.data.keys())
 
 
-def _terratorch_field_factory(input_definition: InputDefinition):
+def _terratorch_field_factory(
+    input_definition: InputDefinition,
+    *,
+    is_shared: bool = True,  # True for unprocessed data, False for processed data
+):
     def _terratorch_field_config(
         hf_inputs: Mapping[str, torch.Tensor],
     ) -> Mapping[str, MultiModalFieldConfig]:
@@ -82,7 +85,11 @@ def _terratorch_field_factory(input_definition: InputDefinition):
         for name, input in input_definition.data.items():
             modality = "image"
             if input.type == InputTypeEnum.tensor:
-                fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
+                fields[name] = (
+                    MultiModalFieldConfig.shared(modality, batch_size=1)
+                    if is_shared
+                    else MultiModalFieldConfig.batched(modality)
+                )
 
         return fields
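For context, a minimal sketch of the two field layouts the new `is_shared` flag switches between; the tensor shape and the lists below are illustrative assumptions, not vLLM code. A shared field hands the whole tensor to a fixed number of items, while a batched field treats dim 0 as the item axis and slices one row per item.

```python
import torch

# Illustrative stand-in for shared vs. batched field semantics.
pixel_values = torch.randn(2, 3, 224, 224)  # hypothetical: 2 items' worth of data

# "shared" with batch_size=1: the entire tensor belongs to a single item.
shared_items = [pixel_values]

# "batched": dim 0 indexes items, so each item receives one slice.
batched_items = [pixel_values[i] for i in range(pixel_values.shape[0])]

print(shared_items[0].shape)   # torch.Size([2, 3, 224, 224]) - whole tensor
print(batched_items[0].shape)  # torch.Size([3, 224, 224])    - per-item slice
```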
@@ -166,8 +173,14 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
+        *,
+        is_shared: bool = True,
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _terratorch_field_factory(self.info.input_definition)(hf_inputs)
+        factory = _terratorch_field_factory(
+            self.info.input_definition,
+            is_shared=is_shared,
+        )
+        return factory(hf_inputs)
 
     def _get_prompt_updates(
         self,
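As a toy re-creation of the factory pattern being extended above (all names here are illustrative stand-ins, not vLLM's classes): the outer function captures the configuration, including the new keyword-only `is_shared`, and returns the inner function that does the per-call work.

```python
from collections.abc import Mapping


def make_field_config(field_names: list[str], *, is_shared: bool = True):
    # The outer function captures configuration; the inner closure uses it
    # each time a set of inputs is mapped to a field layout.
    def field_config(inputs: Mapping[str, object]) -> dict[str, str]:
        kind = "shared" if is_shared else "batched"
        return {name: kind for name in field_names if name in inputs}

    return field_config


factory = make_field_config(["pixel_values"], is_shared=False)
print(factory({"pixel_values": None}))  # {'pixel_values': 'batched'}
```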
@@ -193,12 +206,19 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
         )
 
         _, passthrough_data = self._get_hf_mm_data(mm_items)
-        mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt")
+        mm_processed_data = BatchFeature(
+            {k: torch.tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
+            tensor_type="pt",
+        )
 
         mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
 
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             mm_processed_data,
-            self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
+            self._get_mm_fields_config(
+                mm_processed_data,
+                hf_processor_mm_kwargs,
+                is_shared=False,
+            ),
         )
 
         return MultiModalInputs(
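A minimal sketch of why the `unsqueeze(0)` above matters, using a hypothetical passthrough value (the coordinate pair is made up): without an explicit leading item axis, a batched field would read the tensor's first data axis as the item axis and split one item's payload across several phantom items.

```python
import torch

raw = torch.tensor([52.5, 13.4])  # hypothetical single-item passthrough value

# Without a leading item axis, batched slicing sees 2 "items" of scalar data.
print([raw[i].shape for i in range(raw.shape[0])])  # [torch.Size([]), torch.Size([])]

# unsqueeze(0) adds the item axis: exactly 1 item carrying the full payload.
batched = raw.unsqueeze(0)
print(batched.shape)     # torch.Size([1, 2])
print(batched[0].shape)  # torch.Size([2]) - the whole original tensor
```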
@@ -235,9 +255,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         self.inference_runner = InferenceRunner(config)
         self.model = self.inference_runner.model
 
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-
         self.pooler = IdentityPooler()
 
     def embed_input_ids(
@@ -262,15 +279,8 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
     ):
-        input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
-        batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
-        model_output = self.inference_runner.forward(**batched_kwargs).output
-
-        # The leading dimension of hidden states needs to equal input length
-        return model_output.expand(
-            input_len, *(-1 for _ in range(model_output.ndim - 1))
-        )
+        model_output = self.inference_runner.forward(**kwargs)
+        return model_output.output
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_list = []
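For reference, a standalone sketch of the `expand` trick the last hunk removes (shapes are assumed for illustration): the old forward replicated a single pooled output along a new leading dimension so the hidden states matched the prompt length; with per-item batching handled upstream, the runner's output can now be returned unchanged.

```python
import torch

input_len = 5                       # hypothetical prompt length
model_output = torch.randn(1, 768)  # hypothetical pooled runner output

# Old behavior: stretch dim 0 to the prompt length; -1 keeps other dims as-is.
# expand() returns a view, so no data is copied.
expanded = model_output.expand(
    input_len, *(-1 for _ in range(model_output.ndim - 1))
)
print(expanded.shape)  # torch.Size([5, 768])
```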