[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
190
tests/kernels/moe/parallel_utils.py
Normal file
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
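For orientation, a minimal sketch of how a test can drive parallel_launch; the worker below is hypothetical and not part of this change, it only illustrates that each spawned process receives a ProcessGroupInfo followed by the extra positional arguments:

def _smoke_worker(pgi: ProcessGroupInfo, msg: str) -> None:
    # Runs once per spawned rank with the process group already initialized.
    x = torch.ones(1, device=pgi.device) * pgi.rank
    torch.distributed.all_reduce(x)  # sums the rank ids across the world
    print(f"{msg}: rank={pgi.rank} total={int(x.item())}")


def test_parallel_launch_smoke():
    # Assumes at least 2 visible CUDA devices on a single node.
    parallel_launch(2, _smoke_worker, "hello")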
@@ -2,18 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import pytest
import torch
import triton.language as tl

from tests.kernels.moe.utils import (batched_moe,
                                     make_quantized_test_activations,
                                     make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 128, 128),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 512, 512),
    (222, 1024, 128),
    (222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    in_dtype: torch.dtype
    quant_dtype: Optional[torch.dtype]
    out_dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
@@ -32,79 +73,127 @@ class BatchedMMTensors:
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 10
            dtype=config.in_dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype)
                        dtype=config.in_dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
            dtype=config.out_dtype)

        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)

        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:

    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):
                    N: int, dtype: torch.dtype,
                    block_shape: Optional[list[int]],
                    per_act_token_quant: bool):
    current_platform.seed_everything(7)

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    test_output = tensors.C
    ref_output = test_output.clone()
    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    num_expert_tokens = torch.randint(low=0,
                                      high=max_tokens_per_expert,
                                      size=(num_experts, ),
                                      device="cuda",
                                      dtype=torch.int32)

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant)

    B, B_q, B_scale, _, _, _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        A_q,
        B_q,
        test_output,
        tensors.num_expert_tokens,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        False,
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
        },
        block_shape=block_shape,
    )

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
        None,
        None,
        None,
    )

    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                      num_expert_tokens,
                                                      A_scale, B_scale,
                                                      block_shape)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
@@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
):
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None or topk > e:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
                                                 n,
                                                 k,
                                                 block_shape=block_shape,
                                                 in_dtype=act_dtype,
                                                 quant_dtype=quant_dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        batched_output = batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )
        baseline_output = torch_experts(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape)

        triton_output = triton_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

    torch.testing.assert_close(triton_output,
                               baseline_output,
                               atol=2e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)
296
tests/kernels/moe/test_block_fp8.py
Normal file
@@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                       native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform

dg_available = False
try:
    import deep_gemm
    dg_available = True
except ImportError:
    pass

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

# Test configurations
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4608, 128),
    (1, 4608, 512),
    (1, 4608, 7168),
    (83, 128, 128),
    (83, 512, 512),
    (83, 1024, 7168),
    (83, 4608, 512),
    (83, 4608, 7168),
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4608, 512),
    (128, 4608, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4608, 512),
    (2048, 4608, 7168),
    (8192, 128, 128),
    (8192, 512, 512),
    (8192, 128, 7168),
    (8192, 1024, 7168),
    (8192, 4608, 512),
    (8192, 4608, 7168),
]

MNK_FACTORS_DG = [
    (128, 128, 128),
    (128, 512, 512),
    (128, 128, 7168),
    (128, 1024, 7168),
    (128, 4608, 128),
    (128, 4608, 512),
    (128, 4608, 7168),
    (192, 128, 128),
    (192, 512, 512),
    (192, 1024, 7168),
    (192, 4608, 512),
    (192, 4608, 7168),
    (1335, 128, 128),
    (1335, 1024, 7168),
    (1335, 4608, 512),
    (1335, 4608, 7168),
    (2048, 128, 128),
    (2048, 512, 512),
    (2048, 128, 7168),
    (2048, 1024, 7168),
    (2048, 4608, 128),
    (2048, 4608, 512),
    (2048, 4608, 7168),
]

BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16]  # [128, 256]
TOP_KS = [1, 2, 6]
SEEDS = [0]


def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                             block_shape):
    """Fused moe with block-wise quantization using native torch."""
    B, D = a.shape
    topk = topk_ids.size(1)
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_fp8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")


@pytest.fixture(autouse=True)
def setup_cuda():
    torch.set_default_device("cuda")


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                  monkeypatch):
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_act_token_quant=False,
                                           block_shape=block_size)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a,
            w1,
            w2,
            w1_s,
            w2_s,
            topk_weights,
            topk_ids,
            block_size,
        )

        out = fused_experts(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            use_fp8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )

        m_out = m_fused_moe(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
        )

    # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
    tol = 0.035 if M < 40000 else 0.039
    torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
    torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids, block_size)

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                   topk_ids)

        if use_cudagraph:
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
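The torch reference above relies on native_per_token_group_quant_fp8 from tests.kernels.quant_utils. A rough sketch of what such a group-wise quantizer computes, assuming symmetric max-abs scaling over contiguous groups of block_k elements and K divisible by the group size (the real helper may differ in detail):

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int):
    # x: (tokens, K) -> one scale per (token, group) pair.
    finfo = torch.finfo(torch.float8_e4m3fn)
    xg = x.view(x.shape[0], -1, group_size).float()
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / finfo.max
    x_q = (xg / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q.view(x.shape), scale.squeeze(-1)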
147
tests/kernels/moe/test_block_int8.py
Normal file
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
                                       native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

DTYPES = [torch.half, torch.bfloat16]

MNK_FACTORS = [
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4096, 128),
    (1, 4096, 512),
    (1, 4096, 7168),
    (33, 128, 128),
    (33, 512, 512),
    (33, 128, 7168),
    (33, 1024, 7168),
    (33, 4096, 128),
    (33, 4096, 512),
    (33, 4096, 7168),
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4096, 512),
    (128, 4096, 7168),
    (222, 128, 128),
    (222, 512, 512),
    (222, 1024, 7168),
    (222, 4096, 512),
    (222, 4096, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4096, 512),
    (2048, 4096, 7168),
]

E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]


# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    torch.set_default_device("cuda")


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference."""
    torch.manual_seed(seed)

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.int8,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_int8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                            block_size)

    # Check results
    torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
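Both block-quantized references above use the same weight layout as the rest of these MoE tests: w1 has shape (E, 2*N, K) and w2 has shape (E, K, N), and the first matmul produces concatenated gate/up projections that SiluAndMul combines. A minimal unquantized single-expert sketch of that data flow (illustrative only, not part of the change):

def one_expert_forward(x: torch.Tensor, w1_e: torch.Tensor,
                       w2_e: torch.Tensor) -> torch.Tensor:
    # x: (tokens, K), w1_e: (2*N, K), w2_e: (K, N)
    inter = x @ w1_e.t()                 # (tokens, 2*N)
    gate, up = inter.chunk(2, dim=-1)    # SiluAndMul semantics
    act = torch.nn.functional.silu(gate) * up
    return act @ w2_e.t()                # (tokens, K)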
@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
        n_b_scales = 2 * n if per_out_channel else 1
        k_b_scales = k if per_out_channel else 1
        # Get the right scale for tests.
        _, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
                                      a_scale,
                                      use_per_token_if_dynamic=per_act_token)
        a_q, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              per_act_token: bool,
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    assert not any([
        t is None for t in [
@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
        'topk_ids': topk_ids,
        'w1_scale': moe_tensors.w1_scale,
        'w2_scale': moe_tensors.w2_scale,
        'a1_scale': moe_tensors.a_scale
        'per_act_token': per_act_token,
        'a1_scale': None  #moe_tensors.a_scale
    }

    num_experts = moe_tensors.w1.size(0)
@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
    triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                  topk_ids)

    cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
    cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)

    # Note 5.5 only needed for larger problem sizes, 5 works ok for
    # the rest.
    torch.testing.assert_close(triton_output,
                               cutlass_output,
                               atol=5e-2,
                               atol=5.5e-2,
                               rtol=1e-2)


@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
        cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
                                   per_act_token)

    torch.cuda.synchronize()
    graph.replay()
@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
    cutlass_output = run_8_bit(mt,
                               topk_weights,
                               topk_ids,
                               per_act_token,
                               num_local_experts=e // ep_size)

    torch.testing.assert_close(triton_output,
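The quantization change above replaces the two-step pattern (derive a scale, then quantize with it) with a single dynamic call, where per_act_token selects per-row scales instead of one per-tensor scale. A hedged usage sketch mirroring the new call shape in the diff:

# per_act_token=True  -> one scale per token (row); False -> per-tensor scale
a_q, a_scale = ops.scaled_fp8_quant(moe_tensors_fp16.a, None,
                                    use_per_token_if_dynamic=per_act_token)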
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
"""
@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm

from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights

if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
@@ -30,10 +29,9 @@ if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

    from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

if has_deep_gemm():
    import deep_gemm

    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts)
@@ -60,25 +58,6 @@ def next_power_of_2(x):
    return 2**math.ceil(math.log2(x))


def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (deep_gemm.ceil_div(m, 128) * 128,
         deep_gemm.ceil_div(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_block_quant_fp8_weights(
    e: int,
    n: int,
@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
    Return weights w1q, w2q, w1_scale, w2_scale
    """
    dtype = torch.bfloat16

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)

    w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
    k_tiles_w1 = (k + block_k - 1) // block_k
    n_tiles_w2 = (k + block_n - 1) // block_n
    k_tiles_w2 = (n + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)

    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
                       device="cuda",
                       dtype=torch.float32)
    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
                       device="cuda",
                       dtype=torch.float32)

    assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    return w1, w2, w1_s, w2_s
    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
    return w1q, w2q, w1_scale, w2_scale


@dataclasses.dataclass
@@ -132,6 +79,7 @@ class TestConfig:
    k: int
    n: int
    num_experts: int
    per_act_token_quant: bool
    block_size: list[int]
    # configs for testing low-latency kernels
    low_latency: bool
@@ -150,8 +98,7 @@ class TestTensors:
    def make(config: TestConfig, rank) -> "TestTensors":

        dtype = torch.bfloat16
        topk, m, k, block_size = (config.topk, config.m, config.k,
                                  config.block_size)
        topk, m, k = (config.topk, config.m, config.k)

        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min
@@ -159,9 +106,7 @@ class TestTensors:
        rank_tokens = torch.randn(
            (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
        rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)

        block_k = block_size[1]
        _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
        rank_token_scales = None

        topk_ids = torch.randint(
            low=0,
@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
        q_dtype=q_dtype,
        block_shape=test_config.block_size)

    fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
                                           world_size=pgi.world_size,
                                           dp_size=dp_size,
                                           block_shape=test_config.block_size)
    fused_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        world_size=pgi.world_size,
        dp_size=dp_size,
        block_shape=test_config.block_size,
        per_act_token_quant=test_config.per_act_token_quant)
    mk = FusedMoEModularKernel(prepare_finalize=a2a,
                               fused_experts=fused_experts)
    return mk
@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """
    import deep_gemm

    m, n, k = mnk
    current_platform.seed_everything(7)
@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None)
@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
                                           int], num_experts: int, topk: int,
                                use_fp8_dispatch: bool, block_size: list[int],
                                world_dp_size: tuple[int, int]):
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep

from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch

if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
@@ -31,7 +31,7 @@ if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

    from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(),
@@ -102,10 +102,6 @@ class TestTensors:
        rank_tokens = torch.randn(
            (config.m, config.k), device="cuda", dtype=token_dtype) / 10
        rank_token_scales = None
        if config.dtype == torch.float8_e4m3fn:
            # low_latency_mode kernels dont support per-token quant.
            _, rank_token_scales = ops.scaled_fp8_quant(
                rank_tokens, use_per_token_if_dynamic=not low_latency_mode)

        topk = torch.randint(low=0,
                             high=config.num_experts,
@@ -121,11 +117,18 @@ class TestTensors:
                           config=config)


def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                        low_latency_mode: bool, hidden_size: int, dp_size: int,
                        num_experts: int, num_local_experts: int,
                        q_dtype: Optional[torch.dtype],
                        use_fp8_dispatch: bool) -> FusedMoEModularKernel:
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    hidden_size: int,
    dp_size: int,
    num_experts: int,
    num_local_experts: int,
    q_dtype: Optional[torch.dtype],
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
) -> FusedMoEModularKernel:

    is_quantized = q_dtype is not None

@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                              deepep_ll_args = ll_args)

    if low_latency_mode:
        assert not per_act_token_quant, "not supported in ll mode"
        fused_experts = BatchedTritonExperts(
            max_num_tokens=MAX_TOKENS_PER_RANK,
            world_size=pgi.world_size,
@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
            use_fp8_w8a8=is_quantized,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False)
            use_int4_w4a16=False,
            per_act_token_quant=False,
        )
    else:
        fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
                                      use_int8_w8a8=False,
                                      use_int8_w8a16=False,
                                      use_int4_w4a16=False,
                                      per_channel_quant=False)
        fused_experts = TritonExperts(
            use_fp8_w8a8=is_quantized,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            per_act_token_quant=per_act_token_quant,
        )

    mk = FusedMoEModularKernel(prepare_finalize=a2a,
                               fused_experts=fused_experts)
    return mk


def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
                     low_latency_mode: bool, dp_size: int,
                     test_tensors: TestTensors, w1: torch.Tensor,
                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                     w2_scale: Optional[torch.Tensor], num_experts: int,
                     use_fp8_dispatch: bool) -> torch.Tensor:
def deep_ep_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    num_experts: int,
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
) -> torch.Tensor:

    num_local_experts = w1.size(0)

@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
        q_dtype = torch.float8_e4m3fn

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
                                                    hidden_size, dp_size,
                                                    num_experts,
                                                    num_local_experts, q_dtype,
                                                    use_fp8_dispatch)
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
        num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)

    out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
    total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
    return out_hidden_states


def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
                   w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                   w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
def torch_moe_impl(
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    using_fp8_dispatch: bool,
    per_act_token_quant: bool,
):

    a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
                                 test_tensors.topk_weights)
@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
        # The DeepEP implementation is requested to dispatch using FP8.
        # For numerical stability for testing, emulate the fp8 dispatch by
        # blockwise quant and de-quant.
        assert not per_act_token_quant
        a = test_tensors.rank_tokens
        aq, aq_scale = per_token_group_quant_fp8(a, 128)
        a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
@@ -310,6 +331,7 @@ def _deep_ep_moe(
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
):

    if not low_latency_mode:
@@ -331,7 +353,8 @@ def _deep_ep_moe(
    with set_current_vllm_config(VllmConfig()):
        # Reference
        torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
                                        w2_scale, use_fp8_dispatch)
                                        w2_scale, use_fp8_dispatch,
                                        per_act_token_quant)

        # Splice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
@@ -356,6 +379,7 @@ def _deep_ep_moe(
            w2_scale_ep,
            config.num_experts,
            use_fp8_dispatch,
            per_act_token_quant,
        )

        torch.testing.assert_close(
@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
                     num_experts: int, topk: int, world_dp_size: tuple[int,
                                                                       int]):
def test_deep_ep_moe(
    dtype: torch.dtype,
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
):
    low_latency_mode = False
    use_fp8_dispatch = False
    m, n, k = mnk
@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
                    per_act_token_quant)


MNKs = [
@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
                    False)
@@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
@@ -142,6 +143,10 @@ def test_fused_moe(
    # Setup test data
    #

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -169,7 +174,7 @@ def test_fused_moe(
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        per_channel_quant=False,
        per_act_token_quant=False,
        block_shape=None)

    def m_fused_moe(
@@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
    if dtype == torch.float32:
        pytest.skip("AITER ROCm test skip for float32")

    monkeypatch.setenv('RANK', "0")
    monkeypatch.setenv('LOCAL_RANK', "0")
    monkeypatch.setenv('WORLD_SIZE', "1")
    monkeypatch.setenv('MASTER_ADDR', 'localhost')
    monkeypatch.setenv('MASTER_PORT', '12345')
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = dict()
    with (set_current_vllm_config(vllm_config),

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
    pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
                allow_module_level=True)

MNK_FACTORS = [

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform

from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch

try:
    from pplx_kernels import AllToAll
@@ -93,7 +93,7 @@ def pplx_cutlass_moe(
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=pgi.world_size,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1
@@ -118,8 +118,6 @@ def pplx_cutlass_moe(
        pgi.world_size,
        rank,
        dp_size,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token=per_act_token,
    )

    experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,

@@ -18,18 +18,20 @@ try:
except ImportError:
    has_pplx = False

from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                            get_default_config)
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import round_up

from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch

requires_pplx = pytest.mark.skipif(
    not has_pplx,
@@ -144,25 +146,6 @@ def torch_batched_moe(
    return torch_finalize(out, topk_weight, topk_ids)


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    num_experts = w1.shape[0]

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
                                  world_size=1,
                                  dp_size=1,
                                  rank=0),
        BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))

    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
    baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
    torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
    batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
    batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)

    torch.testing.assert_close(baseline_output,
                               torch_output,
@@ -226,7 +209,6 @@ def pplx_prepare_finalize(

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    block_size = 128
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
@@ -241,9 +223,7 @@ def pplx_prepare_finalize(
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
                                ((hidden_dim + block_size - 1) // block_size *
                                 torch.float32.itemsize)),
        hidden_dim_scale_bytes=0,
    )

    if group_name is None:
@@ -260,7 +240,6 @@ def pplx_prepare_finalize(
        world_size,
        rank,
        dp_size,
        a.dtype,
    )

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
@@ -276,6 +255,7 @@ def pplx_prepare_finalize(
        num_experts,
        None,
        False,
        FusedMoEQuantConfig(),
    )

    b_a = b_a * 1.5
@@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
# TODO (bnell) add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@@ -386,18 +367,31 @@ def pplx_moe(
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    qtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    device = torch.device("cuda", rank)
    hidden_dim = a.shape[1]
    num_experts = w1.shape[0]
    block_size = 128
    topk = topk_ids.shape[1]
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        a.dtype,
        qtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    args = dict(
        max_num_tokens=max_num_tokens,
@@ -407,10 +401,8 @@ def pplx_moe(
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
                                ((hidden_dim + block_size - 1) // block_size *
                                 torch.float32.itemsize)),
        hidden_dim_bytes=hidden_dim_bytes,
        hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
@@ -429,9 +421,11 @@ def pplx_moe(
        dp_size,
    )

    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                   world_size=world_size,
                                   dp_size=dp_size)
                                   dp_size=dp_size,
                                   use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                                   block_shape=block_shape)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
@@ -447,6 +441,13 @@ def pplx_moe(
    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

    if w1_scale is not None:
        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
    else:
        w1_scale_chunk = None
        w2_scale_chunk = None

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
@@ -465,6 +466,8 @@ def pplx_moe(
        w2_chunk,
        chunk_topk_weight,
        chunk_topk_ids,
        w1_scale=w1_scale_chunk,
        w2_scale=w2_scale_chunk,
        global_num_experts=num_experts)

    if use_cudagraphs:
@@ -477,6 +480,8 @@ def pplx_moe(
            w2_chunk,
            chunk_topk_weight,
            chunk_topk_ids,
            w1_scale=w1_scale_chunk,
            w2_scale=w2_scale_chunk,
            global_num_experts=num_experts)

    torch.cuda.synchronize()
@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
        rank=rank,
    )

    experts = BatchedExperts(max_num_tokens=a.shape[0],
                             world_size=1,
                             dp_size=1)
    experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
                                  world_size=1,
                                  dp_size=1)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
@@ -539,7 +544,12 @@ def _pplx_moe(
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    use_internode: bool,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
|
||||
qtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_internode: bool = False,
|
||||
):
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
@@ -557,11 +567,28 @@ def _pplx_moe(
|
||||
|
||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||
|
||||
device = torch.device("cuda", pgi.rank)
|
||||
a = a.to(device)
|
||||
w1 = w1.to(device)
|
||||
w2 = w2.to(device)
|
||||
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape)
|
||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||
a, w1, w2, topk_weight, topk_ids)
|
||||
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
|
||||
qtype, per_act_token_quant, block_shape)
|
||||
# TODO (bnell): fix + re-enable
|
||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||
# topk_ids)
|
||||
@@ -581,6 +608,8 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@requires_pplx
|
||||
def test_pplx_moe(
|
||||
@@ -589,15 +618,33 @@ def test_pplx_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
world_size, dp_size = world_dp_size
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
use_fp8_w8a8 = True
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
use_fp8_w8a8 = False
|
||||
quant_dtype = None
|
||||
|
||||
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip quantization test for non-quantized type")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
||||
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||
use_internode)
|
||||
|
||||
@@ -1,194 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
                                       per_block_cast_to_int8)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
from vllm.utils import round_up


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale,
                         per_channel_quant=per_act_token_quant,
                         use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                         block_shape=block_shape)


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  world_size=1,
                                  dp_size=1,
                                  rank=0),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            world_size=1,
            dp_size=1,
            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )


## DeepEP specific utils
    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale)


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int
def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts,
                                      quant_dtype=q_dtype,
                                      block_shape=block_shape)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        quant_dtype=q_dtype,
        block_shape=block_shape,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  world_size=1,
                                  dp_size=1,
                                  rank=0),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            dp_size=1,
            world_size=1,
            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )

    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale)

def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None


def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
    a_q = a
    a_scale = None

    if quant_dtype is not None:
        assert (quant_dtype == torch.float8_e4m3fn
                or quant_dtype == torch.int8), "only fp8/int8 supported"
        a_q = torch.zeros_like(a, dtype=quant_dtype)
        a_scale_l = [None] * E
        for e in range(E):
            a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
                a[e], None, quant_dtype, per_act_token_quant, block_shape)
        a_scale = torch.stack(a_scale_l)

        if not per_act_token_quant and block_shape is None:
            a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale


def moe_quantize_weights(
    w: torch.Tensor,
    w_s: Optional[torch.Tensor],
    quant_dtype: Optional[torch.dtype],
    per_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert (quant_dtype == torch.float8_e4m3fn
            or quant_dtype == torch.int8), "only fp8/int8 supported"

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        else:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant)
        else:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant)

    return w, w_s


def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        for idx in range(e):
            w_l[idx], w_s_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)

        if block_shape is not None:
            block_n, block_k = block_shape
            n_tiles = (rows + block_n - 1) // block_n
            k_tiles = (cols + block_k - 1) // block_k
            assert w_s.shape == (e, n_tiles, k_tiles)
    else:
        w = w_16
        w_s = None

    return w_16, w, w_s


def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
           torch.Tensor, Optional[torch.Tensor]]:
    return (
        *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
                          per_act_token_quant),
        *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                          per_act_token_quant),
    )