[Refactor][TPU] Remove torch_xla path and use tpu-inference (#30808)
Signed-off-by: Wei-Yu Lin <weiyulin@google.com> Signed-off-by: weiyu <62784299+weiyu0824@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
|
||||
Reference in New Issue
Block a user