[Bugfix] Fix test_mixtral_moe (#24371)
This commit is contained in:
@@ -371,8 +371,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
|
||||||
monkeypatch):
|
use_rocm_aiter: bool, monkeypatch):
|
||||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||||
huggingface."""
|
huggingface."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user