Fix pipeline parallel with multimodal models with the Transformers modelling backend (#37057)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-16 10:20:37 +00:00
committed by GitHub
parent d8f8a7aad2
commit 122f75d939
2 changed files with 31 additions and 8 deletions

View File

@@ -16,8 +16,9 @@
# limitations under the License.
"""Transformers modeling backend base class."""
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from itertools import chain
from operator import attrgetter
from typing import TYPE_CHECKING
import regex as re
@@ -296,6 +297,15 @@ class Base(
# Apply mapping to quantization config if needed
self._maybe_apply_model_mapping()
def _get_tie_word_embeddings(self):
"""
Check if the model has tied word embeddings.
"""
# Transformers v4 and v5 will store this in different places
tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
return tie_word_embeddings_v4 or tie_word_embeddings_v5
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
@@ -311,11 +321,22 @@ class Base(
f"{type(self.model)} does not support pipeline parallel. {tip}"
)
def attrsetter(attr: str) -> Callable[[object, object], None]:
    """Build a setter for a possibly dotted attribute path.

    The inverse of ``operator.attrgetter``: ``attrsetter("a.b")(obj, v)``
    performs ``obj.a.b = v``.
    """
    parent_path, _, leaf = attr.rpartition(".")

    def setter(obj: object, value: object) -> None:
        # Resolve the parent object first when the path is nested;
        # a bare name is set directly on ``obj``.
        target = attrgetter(parent_path)(obj) if parent_path else obj
        setattr(target, leaf, value)

    return setter
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
# attrgetter in case the module is nested (e.g. "text_model.layers")
if isinstance(attrgetter(name)(self.model), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
@@ -330,11 +351,11 @@ class Base(
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (
getattr(self.text_config, "tie_word_embeddings", False)
and self.pp_group.is_last_rank
self._get_tie_word_embeddings() and self.pp_group.is_last_rank
):
continue
setattr(self.model, name, PPMissingLayer())
# attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
attrsetter(name)(self.model, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(
@@ -343,7 +364,8 @@ class Base(
self.pp_group.world_size,
)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
# attrgetter in case the module is nested (e.g. "text_model.layers")
layers = attrgetter(layers_name)(self.model)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
@@ -353,7 +375,8 @@ class Base(
for name in pp_plan[module_list_idx + 1 :]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
# attrsetter in case the module is nested (e.g. "text_model.norm")
attrsetter(name)(self.model, PPMissingLayer())
def recursive_replace(self):
"""Recursively replace modules in the model as needed.

View File

@@ -38,7 +38,7 @@ class CausalMixin(VllmModelForTextGeneration):
# Tell `Base.load_weights` to skip
# `lm_head` if the model has tied word embeddings
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
tie_word_embeddings = self._get_tie_word_embeddings()
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.")