[TPU] make ptxla not imported when using tpu_commons (#23081)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
@@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader):
             )
 
         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
 
-            def _xla_weights_iterator(iterator: Generator):
-                for weights in iterator:
-                    yield weights
-                    xm.mark_step()
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call `xm.mark_step`
+                # frequently so that not too many ops are accumulated
+                # in the XLA program.
+                import torch_xla.core.xla_model as xm
 
-            weights_iterator = _xla_weights_iterator(weights_iterator)
+                def _xla_weights_iterator(iterator: Generator):
+                    for weights in iterator:
+                        yield weights
+                        xm.mark_step()
+
+                weights_iterator = _xla_weights_iterator(weights_iterator)
 
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
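
For readers outside the vLLM codebase, here is a minimal, self-contained sketch of the two ideas the diff combines: the torch_xla import is gated behind a flag so that tpu_commons deployments never import PyTorch/XLA, and on the XLA path the weights iterator is wrapped so `xm.mark_step()` runs after each yielded weight. The module-level `USE_TPU_COMMONS` constant, `wrap_weights_iterator`, and the fake weights list are hypothetical illustrations, not vLLM code.

from typing import Generator, Iterable

USE_TPU_COMMONS = True  # stand-in flag: pretend tpu_commons is active


def wrap_weights_iterator(iterator: Iterable) -> Generator:
    """Yield weights unchanged, adding a mark_step after each one only
    when PyTorch/XLA is actually in use."""
    if not USE_TPU_COMMONS:
        # The import lives inside the branch, so merely loading this
        # module never pulls in torch_xla -- the point of the commit.
        import torch_xla.core.xla_model as xm

        def _xla_weights_iterator(it):
            for weights in it:
                yield weights
                # Flush pending lazy ops so the XLA program stays small.
                xm.mark_step()

        return _xla_weights_iterator(iterator)
    # tpu_commons path: no XLA graph to flush, pass weights straight through.
    return (w for w in iterator)


if __name__ == "__main__":
    fake_weights = [("layer.0.weight", [1.0, 2.0]), ("layer.0.bias", [0.0])]
    for name, value in wrap_weights_iterator(fake_weights):
        print(name, value)

Keeping the import inside the `if not USE_TPU_COMMONS:` branch is what makes the change work: Python executes an import statement only when that line runs, so the tpu_commons path never pays the torch_xla import cost and never fails if torch_xla is not installed.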
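
The diff does not show where `USE_TPU_COMMONS` comes from. A common way to derive such a flag is an import probe at module load; the snippet below is an assumption about that pattern, not a quote of `vllm/platforms/tpu.py`.

# Hypothetical feature probe for an optional dependency.
try:
    import tpu_commons  # noqa: F401
    USE_TPU_COMMONS = True
except ImportError:
    USE_TPU_COMMONS = False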