From fe0411fc6fa32cebeacd3a3aef87a591e7309c45 Mon Sep 17 00:00:00 2001
From: 947132885 <947132885@qq.com>
Date: Sun, 17 Aug 2025 16:46:36 +0800
Subject: [PATCH] [Bugfix] should use stack instead of concat (#22972)

Signed-off-by: 947132885 <947132885@qq.com>
Signed-off-by: Isotr0py
Co-authored-by: Isotr0py
---
Review notes (ignored by git-am):
- Restored one patch row per line; the hunks had been collapsed onto a
  single line. Context indentation was reconstructed and should be
  confirmed against the target file.
- flatten_and_concat now rejects an empty list instead of recursing on it,
  and carries a full docstring. Hunk counts and diffstat are updated; the
  post-image blob id on the "index" line is from the original revision.

 vllm/model_executor/models/transformers.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 4ec2b683f..f3b7263ca 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -694,6 +694,33 @@ class TransformersForCausalLM(TransformersBase):
         return logits
 
 
+def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
+    """Concatenate tensors along dim 0, flattening until shapes agree.
+
+    The tensors are concatenated as soon as they all share the same
+    trailing shape (``shape[1:]``). Otherwise the list is passed through
+    ``flatten_bn`` and the check is repeated on the result.
+
+    Args:
+        x: Non-empty list of tensors.
+
+    Returns:
+        The result of ``torch.concat`` over the (possibly flattened) list.
+
+    Raises:
+        ValueError: If ``x`` is empty.
+    """
+    if not x:
+        # An empty list can never pass the shape check below, so fail fast
+        # instead of recursing on it.
+        raise ValueError("Cannot flatten and concat an empty list.")
+
+    # Concatenation along dim 0 needs every tensor to agree on shape[1:].
+    if len({_x.shape[1:] for _x in x}) == 1:
+        return torch.concat(x)
+    return flatten_and_concat(flatten_bn(x))
+
+
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
@@ -766,8 +793,7 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
             if isinstance(pixel_values, torch.Tensor):
                 pixel_values = flatten_bn(pixel_values).to(self.dtype)
             elif is_list_of(pixel_values, torch.Tensor):
-                pixel_values = flatten_bn(flatten_bn(pixel_values),
-                                          concat=True).to(self.dtype)
+                pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
             else:
                 raise ValueError(
                     f"Unsupported pixel_values type {type(pixel_values)}. "