[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)

This commit is contained in:
youkaichao
2024-08-13 00:16:42 -07:00
committed by GitHub
parent 7025b11d94
commit 4d2dc5072b
8 changed files with 29 additions and 33 deletions

View File

@@ -1,6 +1,7 @@
import torch.nn as nn
from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
class CustomOp(nn.Module):
@@ -54,7 +55,7 @@ class CustomOp(nn.Module):
return self.forward_hip
elif is_cpu():
return self.forward_cpu
elif is_tpu():
elif current_platform.is_tpu():
return self.forward_tpu
elif is_xpu():
return self.forward_xpu