[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
This commit is contained in:
@@ -454,7 +454,7 @@ class XIELU(CustomOp):
|
||||
)
|
||||
return result.view(original_shape)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
def forward_native(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self._xielu_cuda_obj is not None and input.is_cuda:
|
||||
if not torch._dynamo.is_compiling():
|
||||
return self._xielu_cuda_fn(input)
|
||||
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
|
||||
)
|
||||
return self._xielu_python(input)
|
||||
|
||||
def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
    """CUDA entry point for XIELU.

    Simply delegates to ``forward_native``: per the visible diff context,
    the native path itself routes to the fused CUDA callable
    (``self._xielu_cuda_fn``) when ``input.is_cuda`` and not under
    ``torch._dynamo`` compilation, and otherwise falls back to the
    pure-Python implementation — so both entry points share one
    dispatch site.
    """
    # NOTE(review): intentional plain delegation; do not duplicate the
    # CUDA/Python routing here.
    native_result = self.forward_native(input)
    return native_result
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
Reference in New Issue
Block a user