Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,13 +11,15 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
fused_topk)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
cutlass_moe_fp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
@@ -39,12 +41,11 @@ MNK_FACTORS = [
|
||||
(224, 3072, 1536),
|
||||
(32768, 1024, 1024),
|
||||
# These sizes trigger wrong answers.
|
||||
#(7232, 2048, 5120),
|
||||
#(40000, 2048, 5120),
|
||||
# (7232, 2048, 5120),
|
||||
# (40000, 2048, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
@@ -60,22 +61,25 @@ class MOETensors:
|
||||
c_strides2: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors(m: int, k: int, n: int, e: int,
|
||||
dtype: torch.dtype) -> "MOETensors":
|
||||
def make_moe_tensors(
|
||||
m: int, k: int, n: int, e: int, dtype: torch.dtype
|
||||
) -> "MOETensors":
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2)
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(
|
||||
a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -93,9 +97,9 @@ class MOETensors8Bit(MOETensors):
|
||||
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool) -> "MOETensors8Bit":
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
|
||||
) -> "MOETensors8Bit":
|
||||
dtype = torch.half
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
@@ -106,24 +110,21 @@ class MOETensors8Bit(MOETensors):
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
a_q, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
|
||||
)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w1[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w2[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
|
||||
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
|
||||
a_d = a_q.float().mul(a_scale).to(dtype)
|
||||
@@ -133,31 +134,37 @@ class MOETensors8Bit(MOETensors):
|
||||
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
return MOETensors8Bit(a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d)
|
||||
return MOETensors8Bit(
|
||||
a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d,
|
||||
)
|
||||
|
||||
|
||||
def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
**cutlass_moe_kwargs):
|
||||
|
||||
def run_with_expert_maps(
|
||||
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
|
||||
):
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||
"c_strides2"
|
||||
"w1_q",
|
||||
"w2_q",
|
||||
"ab_strides1",
|
||||
"ab_strides2",
|
||||
"c_strides1",
|
||||
"c_strides2",
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
@@ -173,9 +180,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
# make expert map
|
||||
expert_map = [-1] * num_experts
|
||||
expert_map[s:e] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
# update cutlass moe arg with expert_map
|
||||
cutlass_moe_kwargs["expert_map"] = expert_map
|
||||
@@ -198,18 +203,26 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale, moe_tensors.a_scale
|
||||
def run_8_bit(
|
||||
moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert not any(
|
||||
[
|
||||
t is None
|
||||
for t in [
|
||||
moe_tensors.w1_q,
|
||||
moe_tensors.w2_q,
|
||||
moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale,
|
||||
moe_tensors.a_scale,
|
||||
]
|
||||
]
|
||||
])
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=moe_tensors.w1_scale,
|
||||
@@ -222,16 +235,16 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
'a': moe_tensors.a,
|
||||
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids': topk_ids,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'quant_config': quant_config,
|
||||
"a": moe_tensors.a,
|
||||
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"ab_strides1": moe_tensors.ab_strides1,
|
||||
"ab_strides2": moe_tensors.ab_strides2,
|
||||
"c_strides1": moe_tensors.c_strides1,
|
||||
"c_strides2": moe_tensors.c_strides2,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@@ -243,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
return run_with_expert_maps(
|
||||
num_experts,
|
||||
num_local_experts, # type: ignore[arg-type]
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@@ -253,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -269,25 +285,18 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
if ep_size is not None:
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
@@ -295,15 +304,15 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
else:
|
||||
number_local_experts = None
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||
per_out_ch, number_local_experts)
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
|
||||
)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@@ -313,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -330,39 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token, per_out_ch)
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=9e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
|
||||
@@ -375,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -388,8 +392,9 @@ def test_cutlass_moe_8_bit_EP(
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||
per_out_channel, monkeypatch, ep_size)
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
)
|
||||
|
||||
|
||||
LARGE_MNK_FACTORS = [
|
||||
@@ -406,8 +411,10 @@ LARGE_MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP_large(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -419,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||
per_out_channel, monkeypatch, ep_size)
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
|
||||
@@ -430,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_run_cutlass_moe_fp8(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -444,14 +454,12 @@ def test_run_cutlass_moe_fp8(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_channel)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(
|
||||
m, k, n, e, per_act_token, per_out_channel
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
# we want to make sure there is at least one token that's generated in
|
||||
# this expert shard and at least one token that's NOT generated in this
|
||||
# expert shard
|
||||
@@ -462,12 +470,12 @@ def test_run_cutlass_moe_fp8(
|
||||
workspace2_shape = (m * topk, max(n, k))
|
||||
output_shape = (m, k)
|
||||
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
workspace13 = torch.empty(
|
||||
prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
workspace2 = torch.empty(
|
||||
prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
|
||||
num_local_experts = e // ep_size
|
||||
start, end = 0, num_local_experts
|
||||
@@ -475,36 +483,55 @@ def test_run_cutlass_moe_fp8(
|
||||
expert_map[start:end] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
|
||||
)
|
||||
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
|
||||
func = lambda output: run_cutlass_moe_fp8(
|
||||
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
|
||||
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
|
||||
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
|
||||
workspace13, workspace2, None, mt.a.dtype, per_act_token,
|
||||
per_out_channel, False, topk_weights)
|
||||
output,
|
||||
a1q,
|
||||
mt.w1_q,
|
||||
mt.w2_q,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
mt.w1_scale,
|
||||
mt.w2_scale,
|
||||
a1q_scale,
|
||||
None,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
workspace13,
|
||||
workspace2,
|
||||
None,
|
||||
mt.a.dtype,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
False,
|
||||
topk_weights,
|
||||
)
|
||||
|
||||
workspace13.random_()
|
||||
output_random_workspace = torch.empty(output_shape,
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
output_random_workspace = torch.empty(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_random_workspace)
|
||||
|
||||
workspace13.fill_(0)
|
||||
output_zero_workspace = torch.zeros(output_shape,
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
output_zero_workspace = torch.zeros(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_zero_workspace)
|
||||
|
||||
torch.testing.assert_close(output_random_workspace,
|
||||
output_zero_workspace,
|
||||
atol=5e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user