[Hardware][TPU] Support MoE with Pallas GMM kernel (#6457)
@@ -104,6 +104,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
+
+    def forward_tpu(
+        self,
+        x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+        assert not use_grouped_topk
+        assert num_expert_group is None
+        assert topk_group is None
+        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
 
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
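For context, below is a minimal, non-Pallas reference sketch of the computation that `moe_pallas.fused_moe(x, w1, w2, router_logits, top_k, renormalize)` is expected to perform: softmax routing over the router logits, top-k expert selection with optional weight renormalization, then a per-expert MLP whose outputs are combined with the routing weights. The gate/up packing of `w1`, the SiLU activation, and the helper name `fused_moe_reference` are illustrative assumptions, not the kernel's actual code.

```python
import torch
import torch.nn.functional as F


def fused_moe_reference(
    x: torch.Tensor,              # [num_tokens, hidden_size]
    w1: torch.Tensor,             # [num_experts, 2 * intermediate_size, hidden_size] (assumed gate/up packing)
    w2: torch.Tensor,             # [num_experts, hidden_size, intermediate_size]
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    top_k: int,
    renormalize: bool,
) -> torch.Tensor:
    # Route each token: softmax over experts, then keep the top_k experts.
    gating = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(gating, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for expert_id in range(w1.shape[0]):
        # Tokens routed to this expert, and which of their top_k slots it occupies.
        token_idx, slot = (topk_ids == expert_id).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        # Assumed SwiGLU-style expert MLP: w1 packs gate and up projections.
        gate, up = (xe @ w1[expert_id].t()).chunk(2, dim=-1)
        h = F.silu(gate) * up
        expert_out = h @ w2[expert_id].t()
        # Accumulate, scaled by this token's routing weight for this expert.
        out[token_idx] += topk_weights[token_idx, slot].unsqueeze(-1).to(x.dtype) * expert_out
    return out
```

The Python loop over experts is only for clarity; the point of the Pallas GMM path named in the commit title is to replace that loop with a single grouped matrix multiply over tokens grouped by expert, which is far better suited to the TPU.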