[OOT] Support sync_model_loading for OOT (#25126)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
Chendi.Xue
2025-09-19 00:41:53 -05:00
committed by GitHub
parent 6c8a3c099b
commit a6149aa587
4 changed files with 33 additions and 17 deletions

View File

@@ -12,7 +12,6 @@ from torch.nn import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
-from vllm.model_executor.utils import _make_synced_weight_loader
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
from vllm.platforms import current_platform
-if current_platform.is_tpu():
-weight_loader = _make_synced_weight_loader(weight_loader)
+if current_platform.use_sync_weight_loader():
+weight_loader = current_platform.make_synced_weight_loader(
+weight_loader)
self._weight_loader = weight_loader
self.tp_rank = get_tensor_model_parallel_rank()