[Bugfix] Fix RuntimeError: Index put requires the source and destination dtypes match (#22065)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2025-08-08 10:20:21 +08:00
committed by GitHub
parent 3303f134e0
commit 17eaaef595
2 changed files with 106 additions and 2 deletions

View File

@@ -401,7 +401,7 @@ def merge_multimodal_embeddings_from_map(
"""
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
placeholder_map.src]
placeholder_map.src].to(dtype=inputs_embeds.dtype)
return inputs_embeds
@@ -421,7 +421,8 @@ def _merge_multimodal_embeddings(
flattened = _flatten_embeddings(multimodal_embeddings)
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype))
except RuntimeError as e:
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)