[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
@@ -44,23 +44,12 @@ def set_weight_attrs(
         # TODO(woosuk): Remove this hack once we have a better solution.
         from vllm.platforms import current_platform
 
-        if current_platform.is_tpu() and key == "weight_loader":
-            value = _make_synced_weight_loader(value)
+        if current_platform.use_sync_weight_loader(
+        ) and key == "weight_loader":
+            value = current_platform.make_synced_weight_loader(value)
         setattr(weight, key, value)
 
 
-def _make_synced_weight_loader(original_weight_loader):
-
-    def _synced_weight_loader(param, *args, **kwargs):
-        out = original_weight_loader(param, *args, **kwargs)
-        # torch._sync doesn't support, is not needed for CPU tensors.
-        if param.device != torch.device("cpu"):
-            torch._sync(param)
-        return out
-
-    return _synced_weight_loader
-
-
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
     parent_map = getattr(model, "packed_modules_mapping", None)
     parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
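For reference, a minimal sketch of how an out-of-tree (OOT) platform plugin could opt in to the hooks this commit introduces. Assumptions: `MyOOTPlatform` is a hypothetical class name, the hooks are overridable classmethods on `vllm.platforms.interface.Platform`, and the `torch._sync` call simply mirrors the TPU-specific loader removed above; a real OOT platform would substitute its own synchronization primitive.

# Sketch only: `MyOOTPlatform` is a hypothetical out-of-tree platform class.
import torch

from vllm.platforms.interface import Platform


class MyOOTPlatform(Platform):
    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        # Opt in so set_weight_attrs wraps every registered weight_loader.
        return True

    @classmethod
    def make_synced_weight_loader(cls, original_weight_loader):
        # Mirrors the removed _make_synced_weight_loader helper: run the
        # original loader, then force lazily-evaluated device tensors to
        # materialize so narrowed views don't pin redundant base storage.
        def _synced_weight_loader(param, *args, **kwargs):
            out = original_weight_loader(param, *args, **kwargs)
            # CPU tensors are eager and need no sync.
            if param.device != torch.device("cpu"):
                torch._sync(param)  # assumption: swap in the platform's own sync op
            return out

        return _synced_weight_loader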