Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_pplx_moe.py`.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import textwrap
|
||||
@@ -15,29 +16,34 @@ import torch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize, nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
has_pplx = True
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
|
||||
_set_vllm_config)
|
||||
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
|
||||
naive_batched_moe)
|
||||
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
|
||||
from tests.kernels.moe.utils import (
|
||||
make_shared_experts,
|
||||
make_test_weights,
|
||||
naive_batched_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import dequant
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -59,7 +65,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
# (1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(4, 128, 128),
|
||||
@@ -91,17 +97,16 @@ def torch_prepare(
|
||||
num_tokens, hidden_dim = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1),
|
||||
minlength=num_experts)
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
|
||||
|
||||
assert tokens_per_expert.numel() == num_experts
|
||||
|
||||
if max_num_tokens is None:
|
||||
max_num_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
|
||||
dtype=a.dtype,
|
||||
device=a.device)
|
||||
b_a = torch.zeros(
|
||||
(num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
|
||||
)
|
||||
|
||||
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
|
||||
|
||||
@@ -109,28 +114,29 @@ def torch_prepare(
|
||||
for j in range(topk):
|
||||
expert_id = topk_ids[token, j]
|
||||
idx = token_counts[expert_id]
|
||||
b_a[expert_id, idx:idx + 1, :] = a[token, :]
|
||||
b_a[expert_id, idx : idx + 1, :] = a[token, :]
|
||||
token_counts[expert_id] = token_counts[expert_id] + 1
|
||||
|
||||
return b_a, tokens_per_expert
|
||||
|
||||
|
||||
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
def torch_finalize(
|
||||
b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = b_out.shape[0]
|
||||
K = b_out.shape[-1]
|
||||
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
|
||||
expert_counts = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=b_out.device)
|
||||
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
|
||||
for token in range(num_tokens):
|
||||
expert_ids = topk_ids[token]
|
||||
for i in range(expert_ids.numel()):
|
||||
expert_id = expert_ids[i]
|
||||
idx = expert_counts[expert_id]
|
||||
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
|
||||
1, :] * topk_weight[token, i]
|
||||
out[token, :] = (
|
||||
out[token, :]
|
||||
+ b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i]
|
||||
)
|
||||
expert_counts[expert_id] = expert_counts[expert_id] + 1
|
||||
|
||||
return out
|
||||
@@ -149,17 +155,18 @@ def torch_batched_moe(
|
||||
num_tokens, topk = topk_ids.shape
|
||||
_, max_num_tokens, K = b_a.shape
|
||||
assert num_experts == b_a.shape[0] and w2.shape[1] == K
|
||||
out = torch.zeros((num_experts, max_num_tokens, K),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
out = torch.zeros(
|
||||
(num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
|
||||
)
|
||||
tmp = torch.empty(
|
||||
(max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
|
||||
)
|
||||
for expert in range(num_experts):
|
||||
num = tokens_per_expert[expert]
|
||||
if num > 0:
|
||||
torch.ops._C.silu_and_mul(
|
||||
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
|
||||
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
)
|
||||
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
|
||||
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
@@ -186,20 +193,16 @@ def test_fused_moe_batched_experts(
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_experts(a, w1, w2, topk_weight,
|
||||
topk_ids) # only for baseline
|
||||
baseline_output = torch_experts(
|
||||
a, w1, w2, topk_weight, topk_ids
|
||||
) # only for baseline
|
||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = naive_batched_moe(
|
||||
a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this
|
||||
a, w1, w2, topk_weight, topk_ids
|
||||
) # pick torch_experts or this
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
torch_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_output,
|
||||
batched_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
|
||||
torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
def create_pplx_prepare_finalize(
|
||||
@@ -217,7 +220,9 @@ def create_pplx_prepare_finalize(
|
||||
group_name: Optional[str],
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
|
||||
PplxPrepareAndFinalize,
|
||||
pplx_hidden_dim_scale_bytes,
|
||||
)
|
||||
|
||||
max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
|
||||
num_local_experts = rank_chunk(num_experts, 0, world_size)
|
||||
@@ -266,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int:
|
||||
|
||||
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
|
||||
chunk = rank_chunk(t.shape[0], r, w)
|
||||
return t[(r * chunk):(r + 1) * chunk]
|
||||
return t[(r * chunk) : (r + 1) * chunk]
|
||||
|
||||
|
||||
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
|
||||
w: int) -> Optional[torch.Tensor]:
|
||||
def maybe_chunk_by_rank(
|
||||
t: Optional[torch.Tensor], r: int, w: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None:
|
||||
return chunk_by_rank(t, r, w)
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
|
||||
w: int) -> Optional[torch.Tensor]:
|
||||
def chunk_scales_by_rank(
|
||||
t: Optional[torch.Tensor], r: int, w: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None and t.numel() > 1:
|
||||
chunk = rank_chunk(t.shape[0], r, w)
|
||||
return t[(r * chunk):(r + 1) * chunk]
|
||||
return t[(r * chunk) : (r + 1) * chunk]
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def chunk_scales(t: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
def chunk_scales(
|
||||
t: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None and t.numel() > 1:
|
||||
return t[start:end]
|
||||
else:
|
||||
@@ -350,8 +358,7 @@ def pplx_prepare_finalize(
|
||||
device=device,
|
||||
)
|
||||
|
||||
if (quant_dtype is not None and not per_act_token_quant
|
||||
and block_shape is None):
|
||||
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
|
||||
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
else:
|
||||
@@ -375,8 +382,7 @@ def pplx_prepare_finalize(
|
||||
),
|
||||
)
|
||||
|
||||
b_a = dummy_work(
|
||||
dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
|
||||
b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
|
||||
|
||||
prepare_finalize.finalize(
|
||||
out,
|
||||
@@ -410,15 +416,17 @@ def _pplx_prepare_finalize(
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks,
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
@@ -426,22 +434,28 @@ def _pplx_prepare_finalize(
|
||||
|
||||
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
|
||||
|
||||
torch_output = (a_rep.view(m, topk, k) *
|
||||
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
|
||||
dim=1)
|
||||
torch_output = (
|
||||
a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
|
||||
topk_ids, num_experts, quant_dtype,
|
||||
block_shape, per_act_token_quant,
|
||||
group_name)
|
||||
pplx_output = pplx_prepare_finalize(
|
||||
pgi,
|
||||
dp_size,
|
||||
a,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
group_name,
|
||||
)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pgi.device)
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
|
||||
pgi.device
|
||||
)
|
||||
|
||||
torch.testing.assert_close(pplx_output,
|
||||
torch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
|
||||
finally:
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
@@ -491,9 +505,19 @@ def test_pplx_prepare_finalize_slow(
|
||||
a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
|
||||
score = torch.randn((m, e), device=device, dtype=act_dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
|
||||
topk, e, quant_dtype, block_shape, per_act_token_quant,
|
||||
use_internode)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_prepare_finalize,
|
||||
dp_size,
|
||||
a,
|
||||
score,
|
||||
topk,
|
||||
e,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
use_internode,
|
||||
)
|
||||
|
||||
|
||||
def pplx_moe(
|
||||
@@ -517,7 +541,6 @@ def pplx_moe(
|
||||
use_cudagraphs: bool = True,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
num_experts = w1.shape[0]
|
||||
topk = topk_ids.shape[1]
|
||||
@@ -579,21 +602,23 @@ def pplx_moe(
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
if use_compile:
|
||||
_fused_experts = torch.compile(fused_experts,
|
||||
backend='inductor',
|
||||
fullgraph=True)
|
||||
_fused_experts = torch.compile(
|
||||
fused_experts, backend="inductor", fullgraph=True
|
||||
)
|
||||
torch._dynamo.mark_dynamic(a_chunk, 0)
|
||||
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
|
||||
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
|
||||
else:
|
||||
_fused_experts = fused_experts
|
||||
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
out = _fused_experts(
|
||||
a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
)
|
||||
|
||||
if use_cudagraphs:
|
||||
if isinstance(out, tuple):
|
||||
@@ -604,12 +629,14 @@ def pplx_moe(
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
out = _fused_experts(
|
||||
a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
@@ -640,15 +667,17 @@ def _pplx_moe(
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks,
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
m, k = a.shape
|
||||
@@ -666,8 +695,7 @@ def _pplx_moe(
|
||||
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||
|
||||
if (quant_dtype is not None and not per_act_token_quant
|
||||
and block_shape is None):
|
||||
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
|
||||
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
else:
|
||||
@@ -742,31 +770,27 @@ def _pplx_moe(
|
||||
if shared_output is not None:
|
||||
assert pplx_shared_output is not None
|
||||
chunked_shared_output = chunk_by_rank(
|
||||
shared_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_shared_output.device)
|
||||
shared_output, pgi.rank, pgi.world_size
|
||||
).to(pplx_shared_output.device)
|
||||
else:
|
||||
chunked_shared_output = None
|
||||
|
||||
chunked_batch_output = chunk_by_rank(
|
||||
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
|
||||
batched_output, pgi.rank, pgi.world_size
|
||||
).to(pplx_output.device)
|
||||
|
||||
torch.testing.assert_close(batched_output,
|
||||
torch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
|
||||
|
||||
torch.testing.assert_close(pplx_output,
|
||||
chunked_batch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(
|
||||
pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
|
||||
)
|
||||
|
||||
if shared_experts is not None:
|
||||
assert chunked_shared_output is not None
|
||||
assert pplx_shared_output is not None
|
||||
torch.testing.assert_close(pplx_shared_output,
|
||||
chunked_shared_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(
|
||||
pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
|
||||
)
|
||||
|
||||
finally:
|
||||
if use_internode:
|
||||
@@ -823,15 +847,33 @@ def test_pplx_moe_slow(
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||
use_internode)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_moe,
|
||||
dp_size,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
e,
|
||||
w1_s,
|
||||
w2_s,
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
block_shape,
|
||||
use_internode,
|
||||
)
|
||||
|
||||
|
||||
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
use_shared_experts: bool, make_weights: bool,
|
||||
test_fn: Callable):
|
||||
|
||||
def _pplx_test_loop(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
use_internode: bool,
|
||||
use_shared_experts: bool,
|
||||
make_weights: bool,
|
||||
test_fn: Callable,
|
||||
):
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
@@ -850,12 +892,12 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
new_vllm_config = copy.deepcopy(vllm_config)
|
||||
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
|
||||
new_vllm_config.parallel_config.enable_expert_parallel = True
|
||||
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
|
||||
pgi.local_rank)
|
||||
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
|
||||
[False, True], [None, [128, 128]])
|
||||
combos = itertools.product(
|
||||
PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
|
||||
)
|
||||
exceptions = []
|
||||
count = 0
|
||||
for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
|
||||
@@ -873,13 +915,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
|
||||
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
|
||||
f"block_shape={block_shape}, use_internode={use_internode}, "
|
||||
f"use_shared_experts={use_shared_experts}")
|
||||
f"use_shared_experts={use_shared_experts}"
|
||||
)
|
||||
|
||||
if not use_fp8_w8a8 and (per_act_token_quant
|
||||
or block_shape is not None):
|
||||
print(
|
||||
f"{test_desc} - Skip quantization test for non-quantized type."
|
||||
)
|
||||
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
|
||||
print(f"{test_desc} - Skip quantization test for non-quantized type.")
|
||||
continue
|
||||
|
||||
if per_act_token_quant and block_shape is not None:
|
||||
@@ -934,10 +974,10 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
if len(exceptions) > 0:
|
||||
raise RuntimeError(
|
||||
f"{len(exceptions)} of {count} tests failed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
f"rank={pgi.rank}."
|
||||
)
|
||||
else:
|
||||
print(f"{count} of {count} tests passed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@@ -950,8 +990,15 @@ def test_pplx_prepare_finalize(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
|
||||
use_internode, False, False, _pplx_prepare_finalize)
|
||||
parallel_launch(
|
||||
world_size * dp_size,
|
||||
_pplx_test_loop,
|
||||
dp_size,
|
||||
use_internode,
|
||||
False,
|
||||
False,
|
||||
_pplx_prepare_finalize,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@@ -966,5 +1013,12 @@ def test_pplx_moe(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
|
||||
use_shared_experts, True, _pplx_moe)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_test_loop,
|
||||
dp_size,
|
||||
use_internode,
|
||||
use_shared_experts,
|
||||
True,
|
||||
_pplx_moe,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user