[Bugfix] Fix auto dtype casting for BatchFeature (#19316)
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -168,10 +168,12 @@ class InputProcessingContext(InputContext):
|
||||
try:
|
||||
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
||||
# this emulates output.to(dtype=self.model_config.dtype)
|
||||
cast_output = json_map_leaves(maybe_cast_dtype, output)
|
||||
if isinstance(output, BatchFeature):
|
||||
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
|
||||
return BatchFeature(cast_output)
|
||||
|
||||
cast_output = json_map_leaves(maybe_cast_dtype, output)
|
||||
|
||||
logger.warning_once(
|
||||
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
|
||||
"Make sure to match the behaviour of `ProcessorMixin` when "
|
||||
|
||||
Reference in New Issue
Block a user