[OOT] Support sync_model_loading for OOT (#25126)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
Chendi.Xue
2025-09-19 00:41:53 -05:00
committed by GitHub
parent 6c8a3c099b
commit a6149aa587
4 changed files with 33 additions and 17 deletions

View File

@@ -44,23 +44,12 @@ def set_weight_attrs(
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
if current_platform.use_sync_weight_loader(
) and key == "weight_loader":
value = current_platform.make_synced_weight_loader(value)
setattr(weight, key, value)
def _make_synced_weight_loader(original_weight_loader):
def _synced_weight_loader(param, *args, **kwargs):
out = original_weight_loader(param, *args, **kwargs)
# torch._sync doesn't support, is not needed for CPU tensors.
if param.device != torch.device("cpu"):
torch._sync(param)
return out
return _synced_weight_loader
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}