[Hardware][TPU] workaround fix for MoE on TPU (#11764)

This commit is contained in:
Avshalom Manevich
2025-01-12 17:53:51 +02:00
committed by GitHub
parent 8bddb73512
commit 263a870ee1
3 changed files with 60 additions and 1 deletions

View File

@@ -20,7 +20,8 @@ if current_platform.is_cuda_alike():
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
# the iterative MoE implementation is used until moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)