[Bugfix] Guard mm_token_type_ids kwarg in get_mrope_input_positions (#35711)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -474,7 +474,19 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
# can't accept arbitrary args, even if its value is `None`
|
||||
kwargs = {}
|
||||
if mm_token_type_ids:
|
||||
kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids)
|
||||
if not hasattr(self, "_get_rope_index_accepts_mm_token_type_ids"):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.model.get_rope_index)
|
||||
params = sig.parameters
|
||||
self._get_rope_index_accepts_mm_token_type_ids = (
|
||||
"mm_token_type_ids" in params
|
||||
or any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
|
||||
)
|
||||
)
|
||||
if self._get_rope_index_accepts_mm_token_type_ids:
|
||||
kwargs["mm_token_type_ids"] = torch.cat(mm_token_type_ids)
|
||||
|
||||
mrope_positions, mrope_position_delta = self.model.get_rope_index(
|
||||
input_ids=torch.tensor(input_tokens).unsqueeze(0),
|
||||
|
||||
Reference in New Issue
Block a user