[Hardware] Initial TPU integration (#5292)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import is_cpu, is_hip
|
||||
from vllm.utils import is_cpu, is_hip, is_tpu
|
||||
|
||||
|
||||
class CustomOp(nn.Module):
|
||||
@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
|
||||
return self.forward_hip
|
||||
elif is_cpu():
|
||||
return self.forward_cpu
|
||||
elif is_tpu():
|
||||
return self.forward_tpu
|
||||
else:
|
||||
return self.forward_cuda
|
||||
|
||||
Reference in New Issue
Block a user