[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
bnellnm
2025-07-02 09:08:27 -04:00
committed by GitHub
parent b95877509b
commit c1909e7e8c
36 changed files with 2698 additions and 1584 deletions

View File

@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
)
## DeepEP specific utils
@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int
@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)

View File

@@ -2,18 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
MNK_FACTORS = [
(1, 128, 128),
(1, 128, 2048),
(1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048),
(32, 128, 128),
(32, 512, 512),
(32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048),
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 128, 128),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 512, 512),
(222, 1024, 128),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
out_dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
@@ -32,79 +73,127 @@ class BatchedMMTensors:
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
dtype=config.in_dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
dtype=config.out_dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
current_platform.seed_everything(7)
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
test_output = tensors.C
ref_output = test_output.clone()
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
pytest.skip("Don't test blocking for non-quantized types.")
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)
A, A_q, A_scale = make_quantized_test_activations(
num_experts,
max_tokens_per_expert,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)
B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
N // 2,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
)
out_shape = (num_experts, max_tokens_per_expert, N)
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
assert A_q.dtype == B_q.dtype
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
A_q,
B_q,
test_output,
tensors.num_expert_tokens,
num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
A_scale,
B_scale,
None,
# Quantization schemes
False,
use_fp8_w8a8,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
},
block_shape=block_shape,
)
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
ref_output = native_batched_masked_quant_matmul(
A,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
if per_act_token_quant and block_shape is not None or topk > e:
pytest.skip("Skip illegal quantization test.")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
baseline_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
triton_output = triton_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
torch.testing.assert_close(triton_output,
baseline_output,
atol=2e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)

View File

@@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168),
(83, 128, 128),
(83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512),
(83, 4608, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4608, 512),
(2048, 4608, 7168),
(8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168),
(8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168),
]
MNK_FACTORS_DG = [
(128, 128, 128),
(128, 512, 512),
(128, 128, 7168),
(128, 1024, 7168),
(128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168),
(192, 128, 128),
(192, 512, 512),
(192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168),
(1335, 128, 128),
(1335, 1024, 7168),
(1335, 4608, 512),
(1335, 4608, 7168),
(2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168),
(2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168),
]
BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16] # [128, 256]
TOP_KS = [1, 2, 6]
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.size(1)
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@pytest.fixture(autouse=True)
def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
block_shape=block_size)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(
a,
w1,
w2,
w1_s,
w2_s,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids, block_size)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512),
(1, 4096, 7168),
(33, 128, 128),
(33, 512, 512),
(33, 128, 7168),
(33, 1024, 7168),
(33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4096, 512),
(128, 4096, 7168),
(222, 128, 128),
(222, 512, 512),
(222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using
native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
"""Sets the default CUDA device for all tests in this module."""
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.int8,
per_act_token_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
# Check results
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)

View File

@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales = 2 * n if per_out_channel else 1
k_b_scales = k if per_out_channel else 1
# Get the right scale for tests.
_, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
a_scale,
use_per_token_if_dynamic=per_act_token)
a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
num_experts = moe_tensors.w1.size(0)
@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output,
cutlass_output,
atol=5e-2,
atol=5.5e-2,
rtol=1e-2)
@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
torch.cuda.synchronize()
graph.replay()
@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output = run_8_bit(mt,
topk_weights,
topk_ids,
per_act_token,
num_local_experts=e // ep_size)
torch.testing.assert_close(triton_output,

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
"""
@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
@@ -30,10 +29,9 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm():
import deep_gemm
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
@@ -60,25 +58,6 @@ def next_power_of_2(x):
return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
Return weights w1q, w2q, w1_scale, w2_scale
"""
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device="cuda",
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device="cuda",
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
return w1q, w2q, w1_scale, w2_scale
@dataclasses.dataclass
@@ -132,6 +79,7 @@ class TestConfig:
k: int
n: int
num_experts: int
per_act_token_quant: bool
block_size: list[int]
# configs for testing low-latency kernels
low_latency: bool
@@ -150,8 +98,7 @@ class TestTensors:
def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16
topk, m, k, block_size = (config.topk, config.m, config.k,
config.block_size)
topk, m, k = (config.topk, config.m, config.k)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -159,9 +106,7 @@ class TestTensors:
rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
rank_token_scales = None
topk_ids = torch.randint(
low=0,
@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype,
block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
import deep_gemm
m, n, k = mnk
current_platform.seed_everything(7)
@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)
@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
int], num_experts: int, topk: int,
use_fp8_dispatch: bool, block_size: list[int],
world_dp_size: tuple[int, int]):
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=True,
use_fp8_dispatch=use_fp8_dispatch,

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
@@ -31,7 +31,7 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep(),
@@ -102,10 +102,6 @@ class TestTensors:
rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
rank_token_scales = None
if config.dtype == torch.float8_e4m3fn:
# low_latency_mode kernels dont support per-token quant.
_, rank_token_scales = ops.scaled_fp8_quant(
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
topk = torch.randint(low=0,
high=config.num_experts,
@@ -121,11 +117,18 @@ class TestTensors:
config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, hidden_size: int, dp_size: int,
num_experts: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
def make_modular_kernel(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
low_latency_mode: bool,
hidden_size: int,
dp_size: int,
num_experts: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ll_args = ll_args)
if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size,
@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False)
use_int4_w4a16=False,
per_act_token_quant=False,
)
else:
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False)
fused_experts = TritonExperts(
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], num_experts: int,
use_fp8_dispatch: bool) -> torch.Tensor:
def deep_ep_moe_impl(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
low_latency_mode: bool,
dp_size: int,
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0)
@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
hidden_size, dp_size,
num_experts,
num_local_experts, q_dtype,
use_fp8_dispatch)
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
def torch_moe_impl(
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights)
@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
@@ -310,6 +331,7 @@ def _deep_ep_moe(
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
if not low_latency_mode:
@@ -331,7 +353,8 @@ def _deep_ep_moe(
with set_current_vllm_config(VllmConfig()):
# Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch)
w2_scale, use_fp8_dispatch,
per_act_token_quant)
# Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@@ -356,6 +379,7 @@ def _deep_ep_moe(
w2_scale_ep,
config.num_experts,
use_fp8_dispatch,
per_act_token_quant,
)
torch.testing.assert_close(
@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int, world_dp_size: tuple[int,
int]):
def test_deep_ep_moe(
dtype: torch.dtype,
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
):
low_latency_mode = False
use_fp8_dispatch = False
m, n, k = mnk
@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
MNKs = [
@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)

View File

@@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
@@ -142,6 +143,10 @@ def test_fused_moe(
# Setup test data
#
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -169,7 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
per_act_token_quant=False,
block_shape=None)
def m_fused_moe(
@@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config),

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
MNK_FACTORS = [

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
@@ -93,7 +93,7 @@ def pplx_cutlass_moe(
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=pgi.world_size,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
@@ -118,8 +118,6 @@ def pplx_cutlass_moe(
pgi.world_size,
rank,
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,

View File

@@ -18,18 +18,20 @@ try:
except ImportError:
has_pplx = False
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif(
not has_pplx,
@@ -144,25 +146,6 @@ def torch_batched_moe(
return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1,
rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output,
torch_output,
@@ -226,7 +209,6 @@ def pplx_prepare_finalize(
topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
@@ -241,9 +223,7 @@ def pplx_prepare_finalize(
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_scale_bytes=0,
)
if group_name is None:
@@ -260,7 +240,6 @@ def pplx_prepare_finalize(
world_size,
rank,
dp_size,
a.dtype,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
@@ -276,6 +255,7 @@ def pplx_prepare_finalize(
num_experts,
None,
False,
FusedMoEQuantConfig(),
)
b_a = b_a * 1.5
@@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
# TODO (bnell) add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@@ -386,18 +367,31 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
max_num_tokens,
hidden_dim,
a.dtype,
qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
args = dict(
max_num_tokens=max_num_tokens,
@@ -407,10 +401,8 @@ def pplx_moe(
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=scale_bytes,
)
if group_name is None:
@@ -429,9 +421,11 @@ def pplx_moe(
dp_size,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size)
dp_size=dp_size,
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
block_shape=block_shape)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@@ -447,6 +441,13 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if w1_scale is not None:
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
else:
w1_scale_chunk = None
w2_scale_chunk = None
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
@@ -465,6 +466,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@@ -477,6 +480,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
rank=rank,
)
experts = BatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@@ -539,7 +544,12 @@ def _pplx_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
use_internode: bool,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
):
if use_internode:
uid = nvshmem_get_unique_id(
@@ -557,11 +567,28 @@ def _pplx_moe(
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
device = torch.device("cuda", pgi.rank)
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids)
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
qtype, per_act_token_quant, block_shape)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
@@ -581,6 +608,8 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
@@ -589,15 +618,33 @@ def test_pplx_moe(
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
use_internode: bool,
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if dtype == torch.float8_e4m3fn:
use_fp8_w8a8 = True
quant_dtype = dtype
else:
use_fp8_w8a8 = False
quant_dtype = None
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
pytest.skip("Skip quantization test for non-quantized type")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
quant_dtype=quant_dtype,
block_shape=block_shape)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode)

View File

@@ -1,194 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0),
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
world_size=1,
dp_size=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
## DeepEP specific utils
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int
def naive_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts,
quant_dtype=q_dtype,
block_shape=block_shape)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
quant_dtype=q_dtype,
block_shape=block_shape,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0),
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
def make_quantized_test_activations(
E: int,
m: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
a_q = a
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
a_scale = a_scale.view(E, 1, 1)
return a, a_q, a_scale
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
if block_shape is not None:
assert not per_token_quant
if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape)
else:
w, w_s = per_block_cast_to_fp8(w, block_shape)
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
else:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
return w, w_s
def make_test_weight(
e: int,
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
if quant_dtype is not None:
w_l = [None] * e
w_s_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
if w_s.ndim == 2:
assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1)
if block_shape is not None:
block_n, block_k = block_shape
n_tiles = (rows + block_n - 1) // block_n
k_tiles = (cols + block_k - 1) // block_k
assert w_s.shape == (e, n_tiles, k_tiles)
else:
w = w_16
w_s = None
return w_16, w, w_s
def make_test_weights(
e: int,
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
)