[Bugfix] Fix auto dtype casting for BatchFeature (#19316)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-06-14 23:13:08 +08:00
committed by GitHub
parent 6fa718a460
commit 2db9044ab6
7 changed files with 85 additions and 57 deletions

View File

@@ -168,10 +168,12 @@ class InputProcessingContext(InputContext):
try:
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
cast_output = json_map_leaves(maybe_cast_dtype, output)
if isinstance(output, BatchFeature):
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
return BatchFeature(cast_output)
cast_output = json_map_leaves(maybe_cast_dtype, output)
logger.warning_once(
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "