[Hardware][TPU] workaround fix for MoE on TPU (#11764)

Avshalom Manevich
2025-01-12 17:53:51 +02:00
committed by GitHub
parent 8bddb73512
commit 263a870ee1
3 changed files with 60 additions and 1 deletion


@@ -14,6 +14,8 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
@@ -46,6 +48,11 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    torch.testing.assert_close(iterative_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
 
 
 @pytest.mark.parametrize("dtype",
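
The diff shows only the test-side changes; the iterative implementation itself lives in `vllm/model_executor/layers/fused_moe/moe_torch_iterative.py` (per the import above) and is not reproduced in this hunk. As a minimal sketch of the technique, the following shows what a loop-over-experts MoE in plain PyTorch can look like, matching the signature the test calls (`iterative_moe(a, w1, w2, score, topk, renormalize=False)`) and assuming the same SwiGLU weight layout as the Triton `fused_moe`, where `w1` stacks the gate and up projections. Names and details are illustrative, not the actual file contents:

```python
# Sketch only: the real moe_torch_iterative module may differ.
# Assumed shapes (following the Triton fused_moe test conventions):
#   hidden_states: [num_tokens, hidden_size]
#   w1: [num_experts, 2 * intermediate_size, hidden_size]  (gate | up stacked)
#   w2: [num_experts, hidden_size, intermediate_size]
#   gating_output: [num_tokens, num_experts]
import torch
import torch.nn.functional as F


def iterative_moe_sketch(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]

    # Routing: softmax over experts, then keep the top-k weights per token.
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(hidden_states.dtype)

    # Iterate over experts instead of launching a fused Triton kernel:
    # every token passes through every expert, and tokens that did not
    # route to the expert contribute with weight zero. This costs
    # O(num_experts) dense matmuls but uses only standard ops.
    out = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        # Per-token weight for this expert (zero where not selected).
        expert_mask = selected_experts == expert_idx
        expert_weight = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)

        # SwiGLU: w1 yields [gate | up]; silu(gate) * up, then down-project.
        x = F.linear(hidden_states, w1[expert_idx])
        x = F.silu(x[:, :intermediate_size]) * x[:, intermediate_size:]
        x = F.linear(x, w2[expert_idx])
        out += x * expert_weight
    return out
```

The point of the workaround is that a path like this uses only dense per-expert matmuls, so it can run on TPU where the Triton fused kernel cannot; the test above holds the iterative output to the same tolerance (atol=2e-2) against the `torch_moe` reference as the Triton output.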