[Core] Automatically cast multi-modal input dtype (#18756)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-27 23:45:48 +08:00
committed by GitHub
parent 6b6d496114
commit 696259ca01
16 changed files with 91 additions and 44 deletions

View File

@@ -929,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-device=self.device)
+batched_mm_inputs = MultiModalKwargs.as_kwargs(
+batched_mm_inputs,
+dtype=self.model_config.dtype,
+device=self.device,
+)
# Run the encoder.
# `curr_group_outputs` is either of the following:
@@ -1874,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-batched_dummy_mm_inputs, device=self.device)
+batched_dummy_mm_inputs,
+dtype=self.model_config.dtype,
+device=self.device,
+)
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(

View File

@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-device=self.device)
+batched_mm_inputs = MultiModalKwargs.as_kwargs(
+batched_mm_inputs,
+dtype=self.model_config.dtype,
+device=self.device,
+)
# Run the encoder.
# `curr_group_outputs` is either of the following:
@@ -1435,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
batch_size)
-return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
-device=self.device)
+return MultiModalKwargs.as_kwargs(
+batched_dummy_mm_inputs,
+dtype=self.model_config.dtype,
+device=self.device,
+)
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: