[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
This commit is contained in:
@@ -1593,7 +1593,7 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def forward(
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -1627,6 +1627,13 @@ class FusedMoE(CustomOp):
|
||||
return (shared_output[..., :og_hidden_states],
|
||||
fused_output[..., :og_hidden_states])
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
return self.forward_native(hidden_states, router_logits)
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
full_hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user