[Core] Automatically cast multi-modal input dtype (#18756)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -929,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         encoder_outputs = []
         for grouped_mm_inputs in grouped_mm_inputs_list:
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_mm_inputs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
 
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1874,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         batched_dummy_mm_inputs = MultiModalKwargs.batch(
             [dummy_mm_kwargs] * max_num_mm_items)
-        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-            batched_dummy_mm_inputs, device=self.device)
+        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )
 
         # Run multimodal encoder.
         dummy_encoder_outputs = self.model.get_multimodal_embeddings(
@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         encoder_outputs = []
         for grouped_mm_inputs in grouped_mm_inputs_list:
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_mm_inputs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
 
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1435,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
 
         batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                          batch_size)
-        return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
-                                          device=self.device)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )
 
 
 def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
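All four hunks make the same change: they thread the model's configured dtype (self.model_config.dtype) through MultiModalKwargs.as_kwargs, so batched multi-modal tensors are cast to the model dtype at the same point where they are moved to the target device. Below is a minimal, self-contained sketch of that casting behavior, not the vLLM implementation; the helper name cast_mm_inputs is hypothetical, and it assumes that only floating-point tensors (e.g. pixel values) should be cast, while integer tensors such as image grid metadata keep their original dtype.

import torch

def cast_mm_inputs(mm_inputs: dict, *, dtype: torch.dtype,
                   device: torch.device) -> dict:
    """Recursively move tensors to `device`, casting float tensors to `dtype`."""
    def _cast(value):
        if isinstance(value, torch.Tensor):
            value = value.to(device=device)
            # Cast only floating-point data; leave e.g. int64 token IDs alone.
            if value.is_floating_point():
                value = value.to(dtype=dtype)
            return value
        if isinstance(value, dict):
            return {k: _cast(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(_cast(v) for v in value)
        return value

    return _cast(mm_inputs)

# Example: pixel values arrive as float32 but the model runs in bfloat16.
mm_inputs = {
    "pixel_values": torch.rand(1, 3, 224, 224, dtype=torch.float32),
    "image_grid_thw": torch.tensor([[1, 16, 16]]),  # integer metadata
}
casted = cast_mm_inputs(mm_inputs, dtype=torch.bfloat16,
                        device=torch.device("cpu"))
assert casted["pixel_values"].dtype == torch.bfloat16
assert casted["image_grid_thw"].dtype == torch.int64

Casting at batching time keeps the cast in one place instead of in each model's encoder path, so preprocessed inputs no longer need to match the model dtype up front when the model runs in float16/bfloat16.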