[OOT] Support sync_model_loading for OOT (#25126)
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
This commit is contained in:
@@ -12,7 +12,6 @@ from torch.nn import Parameter
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.utils import _make_synced_weight_loader
|
||||
|
||||
__all__ = [
|
||||
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
||||
@@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter):
|
||||
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||
# we sync the param tensor after its weight loader is called.
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_tpu():
|
||||
weight_loader = _make_synced_weight_loader(weight_loader)
|
||||
if current_platform.use_sync_weight_loader():
|
||||
weight_loader = current_platform.make_synced_weight_loader(
|
||||
weight_loader)
|
||||
|
||||
self._weight_loader = weight_loader
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
Reference in New Issue
Block a user