[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
This commit is contained in:
Konrad Zawora
2025-09-11 19:15:01 +02:00
committed by GitHub
parent 1fdd5c42d7
commit 4aa23892d6
8 changed files with 53 additions and 30 deletions

View File

@@ -454,7 +454,7 @@ class XIELU(CustomOp):
)
return result.view(original_shape)
def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward_native(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ class XIELU(CustomOp):
)
return self._xielu_python(input)
def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_native(input)
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.