[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)
Signed-off-by: Bill Nell <bnell@redhat.com>
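The "torch op" part of this change follows the standard custom-op registration pattern: kernels invoked through the torch.ops registry, rather than as bare Python functions, can be treated by torch.compile and CUDA graph capture as opaque, replayable units. Below is a minimal sketch of that pattern using torch.library.custom_op; the moe_sketch namespace, the signatures, and both bodies are illustrative placeholders, not vLLM's actual registration code.

import torch

# Illustrative registration: "moe_sketch" is a made-up namespace and the body
# stands in for the real fused-MoE kernel dispatch.
@torch.library.custom_op("moe_sketch::fused_experts", mutates_args=())
def fused_experts(hidden_states: torch.Tensor,
                  topk_ids: torch.Tensor) -> torch.Tensor:
    return hidden_states.clone()

# A fake (meta) implementation lets tracing reason about output shapes
# without launching the kernel.
@fused_experts.register_fake
def _(hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(hidden_states)

# Callers then go through the registry:
# out = torch.ops.moe_sketch.fused_experts(x, ids)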
@@ -29,7 +29,10 @@ MNK_FACTORS = [
     (224, 1024, 1536),
     (224, 3072, 1024),
     (224, 3072, 1536),
-    (1024 * 128, 1024, 1024),
+    (32768, 1024, 1024),
+    # These sizes trigger wrong answers.
+    #(7232, 2048, 5120),
+    #(40000, 2048, 5120),
 ]
 
 vllm_config = VllmConfig(parallel_config=ParallelConfig(
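Each (m, n, k) tuple in MNK_FACTORS is one problem size: m tokens against expert weights shaped by n and k. A hedged sketch of how such a table typically drives the tests; the parametrize usage is assumed from the shape of the list, and the body is a placeholder:

import pytest

@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
def test_moe_shapes(m: int, n: int, k: int):
    # Placeholder: the real tests build 8-bit MoE tensors of this size and
    # compare the CUTLASS path against a reference implementation.
    assert m > 0 and n > 0 and k > 0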
@@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
     topk: int,
     per_act_token: bool,
     per_out_ch: bool,
+    monkeypatch,
 ):
     current_platform.seed_everything(7)
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
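The added monkeypatch.setenv line pins the fused-MoE chunk size for the duration of each test, so the chunked code path is exercised deterministically. A self-contained sketch of the pytest pattern; read_chunk_size and its default are hypothetical stand-ins for however vLLM actually consumes the variable:

import os

def read_chunk_size(default: int = 65536) -> int:
    # Hypothetical reader, shown only to make the override observable.
    return int(os.environ.get("VLLM_FUSED_MOE_CHUNK_SIZE", default))

def test_chunk_size_override(monkeypatch):
    # setenv is scoped to this test; pytest restores the environment after.
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    assert read_chunk_size() == 8192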
@@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
     topk: int,
     per_act_token: bool,
     per_out_ch: bool,
+    monkeypatch,
 ):
     current_platform.seed_everything(7)
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         dtype = torch.half
 
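test_cutlass_moe_8_bit_cuda_graph runs the kernels under CUDA graph capture and replay, which is what the torch-op change is meant to keep working. A generic sketch of the standard torch.cuda.graph pattern such a test relies on, not vLLM's actual harness:

import torch

def capture_and_replay(fn, *static_inputs):
    # Warm up on a side stream so lazy allocations happen before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = fn(*static_inputs)  # recorded, not executed
    g.replay()  # re-runs the recorded kernels on the same buffers
    torch.cuda.synchronize()
    return static_out

Replay reuses the tensor addresses recorded at capture time, which is why such tests must keep inputs in fixed buffers between capture and replay.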
@@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP(
     per_act_token: bool,
     per_out_channel: bool,
     ep_size: int,
+    monkeypatch,
 ):
     current_platform.seed_everything(7)
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
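In the expert-parallel (EP) variant, ep_size determines how the e experts are divided among ranks. A toy illustration of the contiguous partitioning this implies; expert_range is a hypothetical helper, and even division of experts across ranks is assumed:

def expert_range(rank: int, e: int, ep_size: int) -> range:
    # Hypothetical: each rank owns a contiguous block of e // ep_size experts.
    per_rank = e // ep_size
    return range(rank * per_rank, (rank + 1) * per_rank)

# 8 experts over 4 ranks: rank 1 owns experts 2 and 3.
assert list(expert_range(1, e=8, ep_size=4)) == [2, 3]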