[Hardware][TPU] workaround fix for MoE on TPU (#11764)
committed by GitHub
parent 8bddb73512
commit 263a870ee1
```diff
@@ -14,6 +14,8 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
```
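The new import pulls in a pure-PyTorch MoE that loops over experts instead of dispatching to the Triton kernel, which is what makes it usable on TPU, where Triton is unavailable. Below is a minimal sketch of how such an iterative reference can be written; it assumes the usual vLLM weight layout (w1 packs the gate and up projections, w2 is the down projection) and the function name `iterative_moe_reference` is illustrative, not the exact code added by this PR.

```python
import torch
import torch.nn.functional as F


def iterative_moe_reference(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    w1: torch.Tensor,  # [num_experts, 2 * intermediate_size, hidden_size]
    w2: torch.Tensor,  # [num_experts, hidden_size, intermediate_size]
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    num_experts = w1.shape[0]

    # Route each token: softmax over experts, then keep the top-k.
    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float32)
    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(hidden_states.dtype)

    # Plain loop over experts: every expert processes all tokens, and
    # per-token routing weights zero out the experts a token did not pick.
    # Wasteful compared to the Triton kernel, but portable.
    out = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        mask = selected_experts == expert_idx
        weight = (topk_weights * mask).sum(dim=-1, keepdim=True)

        # SwiGLU MLP: w1 output is [gate; up], activated as silu(gate) * up.
        x = F.linear(hidden_states, w1[expert_idx])
        gate, up = x.chunk(2, dim=-1)
        x = F.linear(F.silu(gate) * up, w2[expert_idx])

        out = out + x * weight
    return out
```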
```diff
@@ -46,6 +48,11 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    torch.testing.assert_close(iterative_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
 
 
 @pytest.mark.parametrize("dtype",
```
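The test now checks the iterative path against the same `torch_moe` reference and the same tolerances (atol=2e-2, rtol=0) already used for the Triton kernel. As a hedged illustration, the sketch above can be driven in isolation like this; the sizes are made up for the example, not the test's actual parametrization.

```python
import torch

# Illustrative shapes; the real test parametrizes m, n, k, e and topk.
m, n, k, e, topk = 33, 128, 511, 8, 2
a = torch.randn(m, k) / 10
w1 = torch.randn(e, 2 * n, k) / 10  # packed [gate; up] projection
w2 = torch.randn(e, k, n) / 10
score = torch.randn(m, e)

out = iterative_moe_reference(a, w1, w2, score, topk=topk, renormalize=False)
assert out.shape == (m, k)
```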