[TPU] add tpu_inference (#27277)
Signed-off-by: Johnny Yang <johnnyyang@google.com>
This commit is contained in:
@@ -12,6 +12,4 @@ ray[data]
|
|||||||
setuptools==78.1.0
|
setuptools==78.1.0
|
||||||
nixl==0.3.0
|
nixl==0.3.0
|
||||||
tpu_info==0.4.0
|
tpu_info==0.4.0
|
||||||
|
tpu-inference==0.11.1
|
||||||
# Install torch_xla
|
|
||||||
torch_xla[tpu, pallas]==2.8.0
|
|
||||||
|
|||||||
@@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase):
|
|||||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||||
return xm.all_gather(input_, dim=dim)
|
return xm.all_gather(input_, dim=dim)
|
||||||
|
|
||||||
|
|
||||||
if USE_TPU_INFERENCE:
|
|
||||||
from tpu_inference.distributed.device_communicators import (
|
|
||||||
TpuCommunicator as TpuInferenceCommunicator,
|
|
||||||
)
|
|
||||||
|
|
||||||
TpuCommunicator = TpuInferenceCommunicator # type: ignore
|
|
||||||
|
|||||||
@@ -267,7 +267,9 @@ class TpuPlatform(Platform):
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
|
from tpu_inference.platforms.tpu_platforms import (
|
||||||
|
TpuPlatform as TpuInferencePlatform,
|
||||||
|
)
|
||||||
|
|
||||||
TpuPlatform = TpuInferencePlatform # type: ignore
|
TpuPlatform = TpuInferencePlatform # type: ignore
|
||||||
USE_TPU_INFERENCE = True
|
USE_TPU_INFERENCE = True
|
||||||
|
|||||||
@@ -346,6 +346,6 @@ class TPUWorker:
|
|||||||
|
|
||||||
|
|
||||||
if USE_TPU_INFERENCE:
|
if USE_TPU_INFERENCE:
|
||||||
from tpu_inference.worker import TPUWorker as TpuInferenceWorker
|
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
|
||||||
|
|
||||||
TPUWorker = TpuInferenceWorker # type: ignore
|
TPUWorker = TpuInferenceWorker # type: ignore
|
||||||
|
|||||||
Reference in New Issue
Block a user