[Hardware][TPU] Support MoE with Pallas GMM kernel (#6457)

Author: Woosuk Kwon
Date: 2024-07-16 09:56:28 -07:00
Committed by: GitHub
Parent: 9f4ccec761
Commit: c467dff24f
5 changed files with 89 additions and 8 deletions
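
The key ingredient named in the commit title is the Pallas grouped matmul (GMM) kernel: token copies are laid out so that all rows routed to the same expert sit contiguously, which lets one kernel launch apply each expert's weight matrix to its own slice of rows instead of launching one matmul per expert. As a minimal sketch of that grouping step (the helper name group_tokens_by_expert is hypothetical; the real kernel-side layout lives in moe_pallas):

import torch


def group_tokens_by_expert(topk_ids: torch.Tensor, num_experts: int):
    # topk_ids: [num_tokens, top_k], the expert chosen for each token copy.
    flat_ids = topk_ids.flatten()
    # A stable sort groups all rows destined for the same expert contiguously.
    order = flat_ids.argsort(stable=True)
    # group_sizes[i]: how many contiguous rows expert i's weights multiply.
    group_sizes = torch.bincount(flat_ids, minlength=num_experts)
    return order, group_sizes

Given this layout, a GMM kernel computes, for every expert i, out[start_i : start_i + group_sizes[i]] = rows[start_i : start_i + group_sizes[i]] @ w[i] in a single fused pass.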


@@ -104,6 +104,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     raise NotImplementedError(
         "The CPU backend currently does not support MoE.")
+
+    def forward_tpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
 
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.