[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
This commit is contained in:
Konrad Zawora
2025-09-11 19:15:01 +02:00
committed by GitHub
parent 1fdd5c42d7
commit 4aa23892d6
8 changed files with 53 additions and 30 deletions

View File

@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
return (shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states])
def forward_cuda(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """CUDA dispatch target: delegate to the platform-agnostic path.

    NOTE(review): CustomOp routes to ``forward_cuda`` on CUDA platforms
    (per this commit, which renamed ``forward`` to ``forward_native`` to
    fix that routing); FusedMoE has no CUDA-specific body here, so it
    reuses ``forward_native`` unchanged.

    Args:
        hidden_states: input activations for the MoE layer
            (shape assumed ``[..., hidden_size]`` — confirm at caller).
        router_logits: per-token expert routing scores.

    Returns:
        Either a single tensor or a ``(shared_output, fused_output)``
        pair — whichever ``forward_native`` produces for this config.
    """
    return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,