[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -113,6 +113,7 @@ def bench_run(
|
|||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
per_act_token: bool,
|
||||||
num_repeats: int,
|
num_repeats: int,
|
||||||
):
|
):
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
@@ -124,7 +125,8 @@ def bench_run(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
a1_scale=a_scale,
|
per_act_token,
|
||||||
|
a1_scale=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cutlass_from_graph(
|
def run_cutlass_from_graph(
|
||||||
@@ -148,7 +150,8 @@ def bench_run(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
a1_scale=a_scale,
|
per_act_token,
|
||||||
|
a1_scale=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_triton_from_graph(
|
def run_triton_from_graph(
|
||||||
@@ -227,6 +230,7 @@ def bench_run(
|
|||||||
"w2_q": w2_q,
|
"w2_q": w2_q,
|
||||||
"w1_scale": w1_scale,
|
"w1_scale": w1_scale,
|
||||||
"w2_scale": w2_scale,
|
"w2_scale": w2_scale,
|
||||||
|
"per_act_token": per_act_token,
|
||||||
# cuda graph params
|
# cuda graph params
|
||||||
"cutlass_graph": cutlass_graph,
|
"cutlass_graph": cutlass_graph,
|
||||||
"triton_graph": triton_graph,
|
"triton_graph": triton_graph,
|
||||||
@@ -287,12 +291,13 @@ def bench_run(
|
|||||||
w2_scale,
|
w2_scale,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
per_act_token,
|
||||||
num_warmup,
|
num_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
|
|||||||
190
tests/kernels/moe/parallel_utils.py
Normal file
190
tests/kernels/moe/parallel_utils.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
DeepEP test utilities
|
||||||
|
"""
|
||||||
|
import dataclasses
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
from torch.multiprocessing import (
|
||||||
|
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
|
from vllm.utils import get_open_port
|
||||||
|
|
||||||
|
# `import importlib` on its own does not guarantee that the `importlib.util`
# submodule has been loaded, so import it explicitly before using find_spec.
import importlib.util

# True when the optional `deep_ep` package can be located on this system.
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    # The DeepEP prepare/finalize classes are only imported when the package
    # is present, so this module stays importable without DeepEP installed.
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)
|
||||||
|
|
||||||
|
## Parallel Processes Utils
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class ProcessGroupInfo:
    """Rank and device description handed to a worker process."""

    # Total number of processes in the job.
    world_size: int
    # Number of processes per node (multiplier used to derive the global rank).
    world_local_size: int
    # Global rank of this process.
    rank: int
    # Index of the node this process runs on.
    node_rank: int
    # Rank of this process within its node.
    local_rank: int
    # Torch device assigned to this process.
    device: torch.device
|
||||||
|
|
||||||
|
|
||||||
|
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Per-process entry point: set up torch.distributed and run `worker`.

    Selects CUDA device `local_rank`, initialises the default process group
    (gloo for CPU, nccl for CUDA), performs a barrier all-reduce and then
    calls `worker(ProcessGroupInfo(...), *args, **kwargs)`.

    Args:
        local_rank: Rank of this process on its node; also the CUDA device
            index that is selected.
        world_size: Total number of processes in the group.
        world_local_size: Number of processes per node.
        node_rank: Index of this node.
        init_method: URL passed to `init_process_group`.
        worker: Callable that receives the `ProcessGroupInfo` first.
        *args: Extra positional arguments forwarded to `worker`.
        **kwargs: Extra keyword arguments forwarded to `worker`.

    Raises:
        Exception: Anything raised by the barrier or by `worker` is printed
            with its traceback and re-raised.
    """
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )

    try:
        # The barrier lives inside the try block so that a failure here is
        # reported like a worker failure and the process group initialised
        # above is still torn down by the finally clause.
        barrier = torch.tensor([rank], device=device)
        torch.distributed.all_reduce(barrier)

        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        # Print from within the child process before re-raising so the
        # failure is visible in that process's output.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn `world_size` processes that each run `worker`.

    Every process executes `_worker_parallel_launch`, which builds a
    `ProcessGroupInfo` and calls `worker(pgi, *args)`. Keyword arguments are
    not supported (asserted), because only positional arguments are
    forwarded through `spawn`.
    """
    assert not kwargs

    host = os.getenv('LOCALHOST', 'localhost')
    init_method = f"tcp://{host}:{get_open_port()}"

    # Positional arguments that follow the process index supplied by spawn:
    # world_size, world_local_size (the same value), node_rank (0),
    # init_method and the worker, followed by the caller's own arguments.
    launch_args = (world_size, world_size, 0, init_method, worker) + args

    spawn(
        _worker_parallel_launch,
        args=launch_args,
        nprocs=world_size,
        join=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
## DeepEP specific utils
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class DeepEPHTArgs:
    """Arguments for building the DeepEP high-throughput all-to-all."""

    # Multiplied by the rank to obtain that rank's expert offset.
    num_local_experts: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class DeepEPLLArgs:
    """Arguments for building the DeepEP low-latency all-to-all."""

    # Used for the RDMA size hint and forwarded to the prepare/finalize object.
    max_tokens_per_rank: int
    # Hidden size used for the RDMA size hint.
    hidden_size: int
    # Total expert count; also determines num_qps_per_rank.
    num_experts: int
    # Forwarded unchanged to DeepEPLLPrepareAndFinalize.
    use_fp8_dispatch: bool
|
||||||
|
|
||||||
|
|
||||||
|
def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
    """Build a `DeepEPHTPrepareAndFinalize` for the high-throughput path.

    A `deep_ep.Buffer` is created on `pg` with a 1GB `num_nvl_bytes`, no RDMA
    bytes, low-latency mode off and one QP per rank. `q_dtype` and
    `block_shape` are accepted but not used here.
    """
    # Imported lazily so the module can be loaded without deep_ep installed.
    import deep_ep

    # high throughput a2a
    ht_buffer = deep_ep.Buffer(
        group=pg,
        num_nvl_bytes=1024 * 1024 * 1024,  # 1GB
        num_rdma_bytes=0,
        low_latency_mode=False,
        num_qps_per_rank=1,
    )

    # Offset passed as rank_expert_offset: rank * experts-per-rank.
    expert_offset = pgi.rank * ht_args.num_local_experts

    return DeepEPHTPrepareAndFinalize(
        buffer=ht_buffer,
        world_size=pgi.world_size,
        rank=pgi.rank,
        dp_size=dp_size,
        rank_expert_offset=expert_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
    """Build a `DeepEPLLPrepareAndFinalize` for the low-latency path.

    The RDMA buffer size comes from DeepEP's own size hint, and the buffer is
    created in low-latency mode with `num_experts // world_size` QPs per
    rank. `q_dtype` and `block_shape` are accepted but not used here.
    """
    # Imported lazily so the module can be loaded without deep_ep installed.
    import deep_ep

    ll = deepep_ll_args

    # low-latency a2a
    rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        ll.max_tokens_per_rank,
        ll.hidden_size,
        pgi.world_size,
        ll.num_experts,
    )
    qps_per_rank = ll.num_experts // pgi.world_size

    ll_buffer = deep_ep.Buffer(
        group=pg,
        num_rdma_bytes=rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=qps_per_rank,
    )

    return DeepEPLLPrepareAndFinalize(
        buffer=ll_buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=ll.max_tokens_per_rank,
        use_fp8_dispatch=ll.use_fp8_dispatch,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    """Dispatch to the high-throughput or low-latency DeepEP builder.

    Exactly one of `deepep_ht_args` / `deepep_ll_args` must be provided;
    passing both or neither fails an assertion.
    """
    if deepep_ht_args is None:
        # No high-throughput arguments: low-latency arguments are required.
        assert deepep_ll_args is not None
        return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                                  block_shape)

    # High-throughput arguments given: low-latency ones must be absent.
    assert deepep_ll_args is None
    return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                              block_shape)
|
||||||
@@ -2,18 +2,59 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from tests.kernels.moe.utils import (batched_moe,
|
||||||
|
make_quantized_test_activations,
|
||||||
|
make_test_weights, triton_moe)
|
||||||
|
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
||||||
|
from tests.kernels.utils import torch_experts
|
||||||
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
invoke_moe_batched_triton_kernel)
|
invoke_moe_batched_triton_kernel)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
MNK_FACTORS = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(1, 128, 2048),
|
||||||
|
(1, 512, 512),
|
||||||
|
(1, 1024, 128),
|
||||||
|
(1, 1024, 2048),
|
||||||
|
(32, 128, 128),
|
||||||
|
(32, 512, 512),
|
||||||
|
(32, 1024, 2048),
|
||||||
|
(45, 128, 128),
|
||||||
|
(45, 128, 2048),
|
||||||
|
(45, 512, 512),
|
||||||
|
(45, 1024, 128),
|
||||||
|
(45, 1024, 2048),
|
||||||
|
(64, 128, 128),
|
||||||
|
(64, 512, 512),
|
||||||
|
(64, 1024, 2048),
|
||||||
|
(222, 128, 128),
|
||||||
|
(222, 128, 2048),
|
||||||
|
(222, 512, 512),
|
||||||
|
(222, 1024, 128),
|
||||||
|
(222, 1024, 2048),
|
||||||
|
]
|
||||||
|
NUM_EXPERTS = [8, 64]
|
||||||
|
TOP_KS = [1, 2, 6]
|
||||||
|
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
vllm_config.scheduler_config.max_num_seqs = 128
|
||||||
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchedMMConfig:
|
class BatchedMMConfig:
|
||||||
dtype: torch.dtype
|
in_dtype: torch.dtype
|
||||||
|
quant_dtype: Optional[torch.dtype]
|
||||||
|
out_dtype: torch.dtype
|
||||||
num_experts: int
|
num_experts: int
|
||||||
max_tokens_per_expert: int
|
max_tokens_per_expert: int
|
||||||
K: int
|
K: int
|
||||||
@@ -32,79 +73,127 @@ class BatchedMMTensors:
|
|||||||
A = torch.randn(
|
A = torch.randn(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype) / 10
|
dtype=config.in_dtype) / 10
|
||||||
B = torch.randn((config.num_experts, config.N, config.K),
|
B = torch.randn((config.num_experts, config.N, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype)
|
dtype=config.in_dtype)
|
||||||
C = torch.zeros(
|
C = torch.zeros(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config.dtype)
|
dtype=config.out_dtype)
|
||||||
|
|
||||||
num_expert_tokens = torch.randint(low=0,
|
num_expert_tokens = torch.randint(low=0,
|
||||||
high=config.max_tokens_per_expert,
|
high=config.max_tokens_per_expert,
|
||||||
size=(config.num_experts, ),
|
size=(config.num_experts, ),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||||
|
|
||||||
|
|
||||||
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
@pytest.mark.parametrize("num_experts", [8, 16, 32])
|
||||||
num_expert_tokens: torch.Tensor) -> torch.Tensor:
|
|
||||||
|
|
||||||
num_expert_tokens_cpu = num_expert_tokens.clone()
|
|
||||||
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
|
||||||
num_experts = num_expert_tokens.size(0)
|
|
||||||
|
|
||||||
for e in range(num_experts):
|
|
||||||
num_tokens = num_expert_tokens_cpu[e]
|
|
||||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
|
||||||
|
|
||||||
return C
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_experts", [16, 32])
|
|
||||||
@pytest.mark.parametrize("max_tokens_per_expert",
|
@pytest.mark.parametrize("max_tokens_per_expert",
|
||||||
[32, 64, 128, 192, 224, 256, 512])
|
[32, 64, 128, 192, 224, 256, 512])
|
||||||
@pytest.mark.parametrize("K", [128, 256, 1024])
|
@pytest.mark.parametrize("K", [128, 256, 1024])
|
||||||
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize("dtype",
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("block_shape", [None])
|
||||||
|
@pytest.mark.parametrize("per_act_token_quant", [False])
|
||||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||||
N: int, dtype: torch.dtype):
|
N: int, dtype: torch.dtype,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
|
per_act_token_quant: bool):
|
||||||
|
current_platform.seed_everything(7)
|
||||||
|
|
||||||
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
|
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||||
tensors = BatchedMMTensors.make_tensors(config)
|
|
||||||
|
|
||||||
test_output = tensors.C
|
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
|
||||||
ref_output = test_output.clone()
|
pytest.skip("Don't test blocking for non-quantized types.")
|
||||||
|
|
||||||
|
if per_act_token_quant and block_shape is not None:
|
||||||
|
pytest.skip("Skip illegal quantization test.")
|
||||||
|
|
||||||
|
if dtype.itemsize == 1:
|
||||||
|
act_dtype = torch.bfloat16
|
||||||
|
quant_dtype = dtype
|
||||||
|
else:
|
||||||
|
act_dtype = dtype
|
||||||
|
quant_dtype = None
|
||||||
|
|
||||||
|
num_expert_tokens = torch.randint(low=0,
|
||||||
|
high=max_tokens_per_expert,
|
||||||
|
size=(num_experts, ),
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
A, A_q, A_scale = make_quantized_test_activations(
|
||||||
|
num_experts,
|
||||||
|
max_tokens_per_expert,
|
||||||
|
K,
|
||||||
|
in_dtype=act_dtype,
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
block_shape=block_shape,
|
||||||
|
per_act_token_quant=per_act_token_quant)
|
||||||
|
|
||||||
|
B, B_q, B_scale, _, _, _ = make_test_weights(
|
||||||
|
num_experts,
|
||||||
|
N // 2,
|
||||||
|
K,
|
||||||
|
in_dtype=act_dtype,
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||||
|
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||||
|
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||||
|
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||||
|
|
||||||
compute_tl_dtype = {
|
compute_tl_dtype = {
|
||||||
torch.float16: tl.float16,
|
torch.float16: tl.float16,
|
||||||
torch.bfloat16: tl.bfloat16,
|
torch.bfloat16: tl.bfloat16,
|
||||||
torch.float32: tl.float32
|
torch.float32: tl.float32
|
||||||
}[test_output.dtype]
|
}[test_output.dtype]
|
||||||
|
|
||||||
|
assert A_q.dtype == B_q.dtype
|
||||||
|
|
||||||
invoke_moe_batched_triton_kernel(
|
invoke_moe_batched_triton_kernel(
|
||||||
tensors.A,
|
A_q,
|
||||||
tensors.B,
|
B_q,
|
||||||
test_output,
|
test_output,
|
||||||
tensors.num_expert_tokens,
|
num_expert_tokens,
|
||||||
compute_tl_dtype,
|
compute_tl_dtype,
|
||||||
# Quantization data
|
# Quantization data
|
||||||
None,
|
A_scale,
|
||||||
None,
|
B_scale,
|
||||||
None,
|
None,
|
||||||
# Quantization schemes
|
# Quantization schemes
|
||||||
False,
|
use_fp8_w8a8,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
config={
|
config={
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 16,
|
"BLOCK_SIZE_N": 16,
|
||||||
"BLOCK_SIZE_K": 16
|
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
|
||||||
})
|
},
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
|
ref_output = native_batched_masked_quant_matmul(
|
||||||
tensors.num_expert_tokens)
|
A,
|
||||||
|
B,
|
||||||
|
ref_output,
|
||||||
|
num_expert_tokens,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
|
||||||
|
num_expert_tokens,
|
||||||
|
A_scale, B_scale,
|
||||||
|
block_shape)
|
||||||
|
|
||||||
rtol, atol = {
|
rtol, atol = {
|
||||||
torch.float16: (6e-2, 6e-2),
|
torch.float16: (6e-2, 6e-2),
|
||||||
@@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
torch.float32: (1e-2, 1e-2),
|
torch.float32: (1e-2, 1e-2),
|
||||||
}[test_output.dtype]
|
}[test_output.dtype]
|
||||||
|
|
||||||
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
|
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
|
||||||
|
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
):
    """Check that `triton_moe` agrees with `torch_experts` and `batched_moe`.

    All three implementations receive the same activations, weights and
    top-k routing; the triton output must match the other two within 2e-2.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    # Quantization options only make sense for the fp8 dtype.
    quant_options_set = per_act_token_quant or block_shape is not None
    if quant_options_set and not use_fp8_w8a8:
        pytest.skip("Skip quantization test for non-quantized type")

    # Per-token and block quantization are mutually exclusive, and topk
    # cannot exceed the number of experts.
    if (per_act_token_quant and block_shape is not None) or topk > e:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # One-byte dtypes are treated as the quantized type on top of bf16
    # activations; anything else runs unquantized in its own dtype.
    is_one_byte = dtype.itemsize == 1
    act_dtype = torch.bfloat16 if is_one_byte else dtype
    quant_dtype = dtype if is_one_byte else None

    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
                                                 n,
                                                 k,
                                                 block_shape=block_shape,
                                                 in_dtype=act_dtype,
                                                 quant_dtype=quant_dtype)

    # Keyword arguments shared by every implementation under test.
    moe_kwargs = dict(
        w1_scale=w1_s,
        w2_scale=w2_s,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    # NOTE(review): indentation is lost in this view; the extent of the
    # config context below is reconstructed -- confirm against the file.
    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids,
                                     **moe_kwargs)
        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
                                        **moe_kwargs)
        triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids,
                                   **moe_kwargs)

    for expected in (baseline_output, batched_output):
        torch.testing.assert_close(triton_output,
                                   expected,
                                   atol=2e-2,
                                   rtol=2e-2)
|
||||||
|
|||||||
296
tests/kernels/moe/test_block_fp8.py
Normal file
296
tests/kernels/moe/test_block_fp8.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.moe.utils import make_test_weights
|
||||||
|
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||||
|
native_w8a8_block_matmul)
|
||||||
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
|
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
fused_topk, modular_triton_fused_moe)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
dg_available = False
|
||||||
|
try:
|
||||||
|
import deep_gemm
|
||||||
|
dg_available = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if current_platform.get_device_capability() < (9, 0):
|
||||||
|
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||||
|
allow_module_level=True)
|
||||||
|
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
vllm_config.scheduler_config.max_num_seqs = 128
|
||||||
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
# Test configurations
|
||||||
|
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||||
|
# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8
|
||||||
|
# and its hidden size is 7168.
|
||||||
|
MNK_FACTORS = [
|
||||||
|
(1, 128, 128),
|
||||||
|
(1, 512, 512),
|
||||||
|
(1, 128, 7168),
|
||||||
|
(1, 1024, 7168),
|
||||||
|
(1, 4608, 128),
|
||||||
|
(1, 4608, 512),
|
||||||
|
(1, 4608, 7168),
|
||||||
|
(83, 128, 128),
|
||||||
|
(83, 512, 512),
|
||||||
|
(83, 1024, 7168),
|
||||||
|
(83, 4608, 512),
|
||||||
|
(83, 4608, 7168),
|
||||||
|
(128, 128, 128),
|
||||||
|
(128, 512, 512),
|
||||||
|
(128, 1024, 7168),
|
||||||
|
(128, 4608, 512),
|
||||||
|
(128, 4608, 7168),
|
||||||
|
(2048, 128, 128),
|
||||||
|
(2048, 1024, 7168),
|
||||||
|
(2048, 4608, 512),
|
||||||
|
(2048, 4608, 7168),
|
||||||
|
(8192, 128, 128),
|
||||||
|
(8192, 512, 512),
|
||||||
|
(8192, 128, 7168),
|
||||||
|
(8192, 1024, 7168),
|
||||||
|
(8192, 4608, 512),
|
||||||
|
(8192, 4608, 7168),
|
||||||
|
]
|
||||||
|
|
||||||
|
MNK_FACTORS_DG = [
|
||||||
|
(128, 128, 128),
|
||||||
|
(128, 512, 512),
|
||||||
|
(128, 128, 7168),
|
||||||
|
(128, 1024, 7168),
|
||||||
|
(128, 4608, 128),
|
||||||
|
(128, 4608, 512),
|
||||||
|
(128, 4608, 7168),
|
||||||
|
(192, 128, 128),
|
||||||
|
(192, 512, 512),
|
||||||
|
(192, 1024, 7168),
|
||||||
|
(192, 4608, 512),
|
||||||
|
(192, 4608, 7168),
|
||||||
|
(1335, 128, 128),
|
||||||
|
(1335, 1024, 7168),
|
||||||
|
(1335, 4608, 512),
|
||||||
|
(1335, 4608, 7168),
|
||||||
|
(2048, 128, 128),
|
||||||
|
(2048, 512, 512),
|
||||||
|
(2048, 128, 7168),
|
||||||
|
(2048, 1024, 7168),
|
||||||
|
(2048, 4608, 128),
|
||||||
|
(2048, 4608, 512),
|
||||||
|
(2048, 4608, 7168),
|
||||||
|
]
|
||||||
|
|
||||||
|
BLOCK_SIZE = [[128, 128]]
|
||||||
|
E = [2, 8, 16] # [128, 256]
|
||||||
|
TOP_KS = [1, 2, 6]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
|
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                             block_shape):
    """Fused moe with block-wise quantization using native torch.

    Each token row of `a` is repeated `topk` times, quantized per token
    group of `block_shape[1]` columns, pushed through every selected
    expert (w1 matmul -> SiluAndMul -> requantize -> w2 matmul), and the
    per-expert results are combined as a `topk_weight`-weighted sum.
    """
    num_tokens, hidden = a.shape
    topk = topk_ids.size(1)

    # One row per (token, selected expert) pair.
    a = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(num_tokens * topk,
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)

    flat_weights = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)

    for expert in range(w1.shape[0]):
        mask = flat_ids == expert
        if not mask.sum():
            # No row was routed to this expert.
            continue

        inter_out = native_w8a8_block_matmul(a_q[mask],
                                             w1[expert],
                                             a_s[mask],
                                             w1_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)
        act_out = SiluAndMul().forward_native(inter_out)
        act_out_q, act_out_s = native_per_token_group_quant_fp8(
            act_out, block_k)
        out[mask] = native_w8a8_block_matmul(act_out_q,
                                             w2[expert],
                                             act_out_s,
                                             w2_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)

    # Weight each expert's contribution and reduce over the topk axis.
    per_expert = out.view(num_tokens, -1, w2.shape[1])
    weights = flat_weights.view(num_tokens, -1, 1).to(out.dtype)
    return (per_expert * weights).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
# Skip all tests if CUDA is not available
|
||||||
|
pytest.importorskip("torch.cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def setup_cuda():
    """Autouse fixture: make "cuda" the default torch device for each test."""
    torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                  monkeypatch):
    """Compare block-quantized fp8 fused MoE against the torch reference.

    Both `fused_experts` and the modular triton kernel must match
    `torch_w8a8_block_fp8_moe` within the tolerance chosen below.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_act_token_quant=False,
                                           block_shape=block_size)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Weight scales are passed identically to both kernels under test.
    scales = dict(w1_scale=w1_s, w2_scale=w2_s)

    # Set the context to avoid lots of warning spam.
    # NOTE(review): indentation is lost in this view; the extent of the
    # config context is reconstructed -- confirm against the file.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids, block_size)

        out = fused_experts(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            use_fp8_w8a8=True,
                            block_shape=block_size,
                            **scales)

        m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids, **scales)

    # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
    tol = 0.039 if M >= 40000 else 0.035
    for actual in (out, m_out):
        torch.testing.assert_close(actual, ref_out, atol=tol, rtol=tol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    """Compare `deep_gemm_moe_fp8` with the torch block-fp8 reference.

    For sufficiently large problems the kernel is additionally captured in a
    CUDA graph and replayed before its output is checked.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    # The quantization block is square, sized by DeepGEMM's M alignment.
    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    # Only exercise CUDA graphs when M exceeds the chunk size and both N and
    # K are at least 1024, on a CUDA-like platform.
    use_cudagraph = (M > chunk_size and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Positional inputs shared by the reference and the kernel under test.
    moe_inputs = (a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

    # Set the context to avoid lots of warning spam.
    # NOTE(review): indentation is lost in this view; the extent of the
    # config context is reconstructed -- confirm against the file.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(*moe_inputs, block_size)

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            for dynamic_tensor in (a, topk_weights, topk_ids):
                torch._dynamo.mark_dynamic(dynamic_tensor, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(*moe_inputs)

        if use_cudagraph:
            # Clear the eager result, then capture and replay the kernel.
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(*moe_inputs)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
|
||||||
147
tests/kernels/moe/test_block_int8.py
Normal file
147
tests/kernels/moe/test_block_int8.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.moe.utils import make_test_weights
|
||||||
|
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||||
|
native_w8a8_block_matmul)
|
||||||
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# The INT8 Triton kernels exercised here need compute capability >= 7.0.
# has_device_capability() is used rather than comparing the result of
# get_device_capability() directly, because that result may be None on
# platforms that report no capability and None cannot be ordered against a
# tuple; in that case the module is skipped instead of erroring at collection.
if not current_platform.has_device_capability(70):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)
|
||||||
|
|
||||||
|
# Module-wide vLLM config. The test enters it with set_current_vllm_config()
# around the kernel calls to avoid warning spam.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
# Activation dtypes the test is parametrized over.
DTYPES = [torch.half, torch.bfloat16]

# (M, N, K) problem sizes, listed in groups of increasing M.
MNK_FACTORS = [
    # M = 1
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4096, 128),
    (1, 4096, 512),
    (1, 4096, 7168),
    # M = 33
    (33, 128, 128),
    (33, 512, 512),
    (33, 128, 7168),
    (33, 1024, 7168),
    (33, 4096, 128),
    (33, 4096, 512),
    (33, 4096, 7168),
    # M = 128
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4096, 512),
    (128, 4096, 7168),
    # M = 222
    (222, 128, 128),
    (222, 512, 512),
    (222, 1024, 7168),
    (222, 4096, 512),
    (222, 4096, 7168),
    # M = 2048
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4096, 512),
    (2048, 4096, 7168),
]

# Expert counts and top-k values to sweep.
E = [8, 24]
TOP_KS = [2, 6]

# Only the 128x128 block shape is currently enabled. Other candidates:
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]

# RNG seeds passed to torch.manual_seed by the test.
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
|
# For test
|
||||||
|
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Reference fused MoE with block-wise INT8 quantization in native torch.

    Selects ``topk`` experts per token from a softmax over ``score``, runs
    both expert matmuls through ``native_w8a8_block_matmul`` on
    group-quantized activations, and returns the sum of the selected experts'
    outputs weighted by their softmax scores.

    Args:
        a: Input activations; unpacked as ``(B, D)``.
        w1: First-layer expert weights, indexed by expert along dim 0.
        w2: Second-layer expert weights, indexed by expert along dim 0;
            ``w2.shape[1]`` is used as the output width.
        w1_s: Scales for ``w1``, indexed by expert.
        w2_s: Scales for ``w2``, indexed by expert.
        score: Router scores; softmax is taken over the last dim.
        topk: Number of experts selected per token.
        block_shape: Quantization block shape; element 1 is used as the
            activation quantization group size.

    Returns:
        Tensor of shape ``(B, w2.shape[1])`` with the dtype of ``a``.
    """
    B, D = a.shape
    # Repeat each token topk times so every (token, expert) pair has a row.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if not mask.any():
            # No row was routed to this expert.
            continue

        inter_out = native_w8a8_block_matmul(a_q[mask],
                                             w1[i],
                                             a_s[mask],
                                             w1_s[i],
                                             block_shape,
                                             output_dtype=a.dtype)
        act_out = SiluAndMul().forward_native(inter_out)
        # Re-quantize the activation output before the second matmul.
        act_out_q, act_out_s = native_per_token_group_quant_int8(
            act_out, block_k)
        out[mask] = native_w8a8_block_matmul(act_out_q,
                                             w2[i],
                                             act_out_s,
                                             w2_s[i],
                                             block_shape,
                                             output_dtype=a.dtype)

    # Weight each expert's rows by its routing score and sum per token.
    weighted = (out.view(B, -1, w2.shape[1]) *
                topk_weight.view(B, -1, 1).to(out.dtype))
    return weighted.sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Make CUDA the default torch device while this module's tests run.

    The tests create tensors without an explicit ``device`` argument, so they
    rely on this default. The previous default device is restored on teardown
    so the setting does not leak into test modules that run later in the same
    session.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cuda")
    try:
        yield
    finally:
        torch.set_default_device(previous_device)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Check ``fused_moe`` with block-quantized W8A8 INT8 weights against the
    native torch reference implementation."""
    torch.manual_seed(seed)

    # Inputs are drawn in a fixed order so results are reproducible per seed.
    hidden_states = torch.randn((M, K), dtype=dtype) / 10
    router_scores = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.int8,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    quant_kwargs = dict(
        renormalize=False,
        use_int8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )

    # Run both paths inside the config context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(hidden_states, w1, w2, router_scores, topk,
                        **quant_kwargs)
        ref_out = torch_w8a8_block_int8_moe(hidden_states, w1, w2, w1_s, w2_s,
                                            router_scores, topk, block_size)

    # Kernel output must match the reference within tolerance.
    torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
||||||
@@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
|
|||||||
n_b_scales = 2 * n if per_out_channel else 1
|
n_b_scales = 2 * n if per_out_channel else 1
|
||||||
k_b_scales = k if per_out_channel else 1
|
k_b_scales = k if per_out_channel else 1
|
||||||
# Get the right scale for tests.
|
# Get the right scale for tests.
|
||||||
_, a_scale = ops.scaled_fp8_quant(
|
a_q, a_scale = ops.scaled_fp8_quant(
|
||||||
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
|
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
|
||||||
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
|
|
||||||
a_scale,
|
|
||||||
use_per_token_if_dynamic=per_act_token)
|
|
||||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||||
|
|
||||||
@@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
|||||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
per_act_token: bool,
|
||||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||||
assert not any([
|
assert not any([
|
||||||
t is None for t in [
|
t is None for t in [
|
||||||
@@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
|||||||
'topk_ids': topk_ids,
|
'topk_ids': topk_ids,
|
||||||
'w1_scale': moe_tensors.w1_scale,
|
'w1_scale': moe_tensors.w1_scale,
|
||||||
'w2_scale': moe_tensors.w2_scale,
|
'w2_scale': moe_tensors.w2_scale,
|
||||||
'a1_scale': moe_tensors.a_scale
|
'per_act_token': per_act_token,
|
||||||
|
'a1_scale': None #moe_tensors.a_scale
|
||||||
}
|
}
|
||||||
|
|
||||||
num_experts = moe_tensors.w1.size(0)
|
num_experts = moe_tensors.w1.size(0)
|
||||||
@@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
|
|||||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||||
topk_ids)
|
topk_ids)
|
||||||
|
|
||||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
|
||||||
|
|
||||||
|
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||||
|
# the rest.
|
||||||
torch.testing.assert_close(triton_output,
|
torch.testing.assert_close(triton_output,
|
||||||
cutlass_output,
|
cutlass_output,
|
||||||
atol=5e-2,
|
atol=5.5e-2,
|
||||||
rtol=1e-2)
|
rtol=1e-2)
|
||||||
|
|
||||||
|
|
||||||
@@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
|||||||
stream = torch.cuda.Stream()
|
stream = torch.cuda.Stream()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph, stream=stream):
|
with torch.cuda.graph(graph, stream=stream):
|
||||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||||
|
per_act_token)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph.replay()
|
graph.replay()
|
||||||
@@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
|
|||||||
cutlass_output = run_8_bit(mt,
|
cutlass_output = run_8_bit(mt,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
per_act_token,
|
||||||
num_local_experts=e // ep_size)
|
num_local_experts=e // ep_size)
|
||||||
|
|
||||||
torch.testing.assert_close(triton_output,
|
torch.testing.assert_close(triton_output,
|
||||||
|
|||||||
@@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
|||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
||||||
per_token_group_quant_fp8)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||||
|
|
||||||
from .utils import ProcessGroupInfo, parallel_launch
|
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||||
|
from .utils import make_test_weights
|
||||||
|
|
||||||
if has_deep_ep():
|
if has_deep_ep():
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||||
@@ -30,10 +29,9 @@ if has_deep_ep():
|
|||||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||||
DeepEPLLPrepareAndFinalize)
|
DeepEPLLPrepareAndFinalize)
|
||||||
|
|
||||||
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||||
|
|
||||||
if has_deep_gemm():
|
if has_deep_gemm():
|
||||||
import deep_gemm
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
@@ -60,25 +58,6 @@ def next_power_of_2(x):
|
|||||||
return 2**math.ceil(math.log2(x))
|
return 2**math.ceil(math.log2(x))
|
||||||
|
|
||||||
|
|
||||||
def per_block_cast_to_fp8(
|
|
||||||
x: torch.Tensor,
|
|
||||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 2
|
|
||||||
m, n = x.shape
|
|
||||||
x_padded = torch.zeros(
|
|
||||||
(deep_gemm.ceil_div(m, 128) * 128,
|
|
||||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
|
||||||
dtype=x.dtype,
|
|
||||||
device=x.device)
|
|
||||||
x_padded[:m, :n] = x
|
|
||||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
|
||||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
|
||||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
|
||||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
|
||||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
|
||||||
return x_scaled_sub, scales
|
|
||||||
|
|
||||||
|
|
||||||
def make_block_quant_fp8_weights(
|
def make_block_quant_fp8_weights(
|
||||||
e: int,
|
e: int,
|
||||||
n: int,
|
n: int,
|
||||||
@@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
|
|||||||
block_size: list[int],
|
block_size: list[int],
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
|
Return weights w1q, w2q, w1_scale, w2_scale
|
||||||
"""
|
"""
|
||||||
dtype = torch.bfloat16
|
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
|
||||||
|
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
return w1q, w2q, w1_scale, w2_scale
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
|
||||||
|
|
||||||
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
|
|
||||||
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
|
||||||
|
|
||||||
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
|
|
||||||
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
|
|
||||||
k_tiles_w1 = (k + block_k - 1) // block_k
|
|
||||||
n_tiles_w2 = (k + block_n - 1) // block_n
|
|
||||||
k_tiles_w2 = (n + block_k - 1) // block_k
|
|
||||||
|
|
||||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
|
||||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.float32)
|
|
||||||
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
|
|
||||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
|
||||||
|
|
||||||
for i in range(e):
|
|
||||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
|
||||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
|
||||||
|
|
||||||
return w1, w2, w1_s, w2_s
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -132,6 +79,7 @@ class TestConfig:
|
|||||||
k: int
|
k: int
|
||||||
n: int
|
n: int
|
||||||
num_experts: int
|
num_experts: int
|
||||||
|
per_act_token_quant: bool
|
||||||
block_size: list[int]
|
block_size: list[int]
|
||||||
# configs for testing low-latency kernels
|
# configs for testing low-latency kernels
|
||||||
low_latency: bool
|
low_latency: bool
|
||||||
@@ -150,8 +98,7 @@ class TestTensors:
|
|||||||
def make(config: TestConfig, rank) -> "TestTensors":
|
def make(config: TestConfig, rank) -> "TestTensors":
|
||||||
|
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
topk, m, k, block_size = (config.topk, config.m, config.k,
|
topk, m, k = (config.topk, config.m, config.k)
|
||||||
config.block_size)
|
|
||||||
|
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
@@ -159,9 +106,7 @@ class TestTensors:
|
|||||||
rank_tokens = torch.randn(
|
rank_tokens = torch.randn(
|
||||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||||
|
rank_token_scales = None
|
||||||
block_k = block_size[1]
|
|
||||||
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
|
|
||||||
|
|
||||||
topk_ids = torch.randint(
|
topk_ids = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
@@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
q_dtype=q_dtype,
|
q_dtype=q_dtype,
|
||||||
block_shape=test_config.block_size)
|
block_shape=test_config.block_size)
|
||||||
|
|
||||||
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
|
fused_experts = BatchedDeepGemmExperts(
|
||||||
world_size=pgi.world_size,
|
max_num_tokens=max_tokens_per_rank,
|
||||||
dp_size=dp_size,
|
world_size=pgi.world_size,
|
||||||
block_shape=test_config.block_size)
|
dp_size=dp_size,
|
||||||
|
block_shape=test_config.block_size,
|
||||||
|
per_act_token_quant=test_config.per_act_token_quant)
|
||||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||||
fused_experts=fused_experts)
|
fused_experts=fused_experts)
|
||||||
return mk
|
return mk
|
||||||
@@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
|||||||
"""
|
"""
|
||||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||||
"""
|
"""
|
||||||
|
import deep_gemm
|
||||||
|
|
||||||
m, n, k = mnk
|
m, n, k = mnk
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
@@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
|||||||
k=k,
|
k=k,
|
||||||
n=n,
|
n=n,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
|
per_act_token_quant=False,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
low_latency=False,
|
low_latency=False,
|
||||||
use_fp8_dispatch=None)
|
use_fp8_dispatch=None)
|
||||||
@@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
|
|||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
@requires_deep_gemm
|
@requires_deep_gemm
|
||||||
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
|
def test_ll_deepep_deepgemm_moe(
|
||||||
int], num_experts: int, topk: int,
|
mnk: tuple[int, int, int],
|
||||||
use_fp8_dispatch: bool, block_size: list[int],
|
num_experts: int,
|
||||||
world_dp_size: tuple[int, int]):
|
topk: int,
|
||||||
|
use_fp8_dispatch: bool,
|
||||||
|
block_size: list[int],
|
||||||
|
world_dp_size: tuple[int, int],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||||
"""
|
"""
|
||||||
@@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
|
|||||||
k=k,
|
k=k,
|
||||||
n=n,
|
n=n,
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
|
per_act_token_quant=False,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
low_latency=True,
|
low_latency=True,
|
||||||
use_fp8_dispatch=use_fp8_dispatch,
|
use_fp8_dispatch=use_fp8_dispatch,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import has_deep_ep
|
from vllm.utils import has_deep_ep
|
||||||
|
|
||||||
from .utils import ProcessGroupInfo, parallel_launch
|
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||||
|
|
||||||
if has_deep_ep():
|
if has_deep_ep():
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||||
@@ -31,7 +31,7 @@ if has_deep_ep():
|
|||||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||||
DeepEPLLPrepareAndFinalize)
|
DeepEPLLPrepareAndFinalize)
|
||||||
|
|
||||||
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||||
|
|
||||||
requires_deep_ep = pytest.mark.skipif(
|
requires_deep_ep = pytest.mark.skipif(
|
||||||
not has_deep_ep(),
|
not has_deep_ep(),
|
||||||
@@ -102,10 +102,6 @@ class TestTensors:
|
|||||||
rank_tokens = torch.randn(
|
rank_tokens = torch.randn(
|
||||||
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||||
rank_token_scales = None
|
rank_token_scales = None
|
||||||
if config.dtype == torch.float8_e4m3fn:
|
|
||||||
# low_latency_mode kernels dont support per-token quant.
|
|
||||||
_, rank_token_scales = ops.scaled_fp8_quant(
|
|
||||||
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
|
|
||||||
|
|
||||||
topk = torch.randint(low=0,
|
topk = torch.randint(low=0,
|
||||||
high=config.num_experts,
|
high=config.num_experts,
|
||||||
@@ -121,11 +117,18 @@ class TestTensors:
|
|||||||
config=config)
|
config=config)
|
||||||
|
|
||||||
|
|
||||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
def make_modular_kernel(
|
||||||
low_latency_mode: bool, hidden_size: int, dp_size: int,
|
pg: ProcessGroup,
|
||||||
num_experts: int, num_local_experts: int,
|
pgi: ProcessGroupInfo,
|
||||||
q_dtype: Optional[torch.dtype],
|
low_latency_mode: bool,
|
||||||
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
|
hidden_size: int,
|
||||||
|
dp_size: int,
|
||||||
|
num_experts: int,
|
||||||
|
num_local_experts: int,
|
||||||
|
q_dtype: Optional[torch.dtype],
|
||||||
|
use_fp8_dispatch: bool,
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
) -> FusedMoEModularKernel:
|
||||||
|
|
||||||
is_quantized = q_dtype is not None
|
is_quantized = q_dtype is not None
|
||||||
|
|
||||||
@@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
deepep_ll_args = ll_args)
|
deepep_ll_args = ll_args)
|
||||||
|
|
||||||
if low_latency_mode:
|
if low_latency_mode:
|
||||||
|
assert not per_act_token_quant, "not supported in ll mode"
|
||||||
fused_experts = BatchedTritonExperts(
|
fused_experts = BatchedTritonExperts(
|
||||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||||
world_size=pgi.world_size,
|
world_size=pgi.world_size,
|
||||||
@@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
use_fp8_w8a8=is_quantized,
|
use_fp8_w8a8=is_quantized,
|
||||||
use_int8_w8a8=False,
|
use_int8_w8a8=False,
|
||||||
use_int8_w8a16=False,
|
use_int8_w8a16=False,
|
||||||
use_int4_w4a16=False)
|
use_int4_w4a16=False,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
|
fused_experts = TritonExperts(
|
||||||
use_int8_w8a8=False,
|
use_fp8_w8a8=is_quantized,
|
||||||
use_int8_w8a16=False,
|
use_int8_w8a8=False,
|
||||||
use_int4_w4a16=False,
|
use_int8_w8a16=False,
|
||||||
per_channel_quant=False)
|
use_int4_w4a16=False,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
)
|
||||||
|
|
||||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||||
fused_experts=fused_experts)
|
fused_experts=fused_experts)
|
||||||
return mk
|
return mk
|
||||||
|
|
||||||
|
|
||||||
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
def deep_ep_moe_impl(
|
||||||
low_latency_mode: bool, dp_size: int,
|
pg: ProcessGroup,
|
||||||
test_tensors: TestTensors, w1: torch.Tensor,
|
pgi: ProcessGroupInfo,
|
||||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
low_latency_mode: bool,
|
||||||
w2_scale: Optional[torch.Tensor], num_experts: int,
|
dp_size: int,
|
||||||
use_fp8_dispatch: bool) -> torch.Tensor:
|
test_tensors: TestTensors,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
num_experts: int,
|
||||||
|
use_fp8_dispatch: bool,
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
num_local_experts = w1.size(0)
|
num_local_experts = w1.size(0)
|
||||||
|
|
||||||
@@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
q_dtype = torch.float8_e4m3fn
|
q_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
# Make modular kernel
|
# Make modular kernel
|
||||||
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
|
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||||
hidden_size, dp_size,
|
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||||
num_experts,
|
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
|
||||||
num_local_experts, q_dtype,
|
|
||||||
use_fp8_dispatch)
|
|
||||||
|
|
||||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||||
@@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
return out_hidden_states
|
return out_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
|
def torch_moe_impl(
|
||||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
test_tensors: TestTensors,
|
||||||
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
using_fp8_dispatch: bool,
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
):
|
||||||
|
|
||||||
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
|
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
|
||||||
test_tensors.topk_weights)
|
test_tensors.topk_weights)
|
||||||
@@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
|
|||||||
# The DeepEP implementation is requested to dispatch using FP8.
|
# The DeepEP implementation is requested to dispatch using FP8.
|
||||||
# For numerical stability for testing, emulate the fp8 dispatch by
|
# For numerical stability for testing, emulate the fp8 dispatch by
|
||||||
# blockwise quant and de-quant.
|
# blockwise quant and de-quant.
|
||||||
|
assert not per_act_token_quant
|
||||||
a = test_tensors.rank_tokens
|
a = test_tensors.rank_tokens
|
||||||
aq, aq_scale = per_token_group_quant_fp8(a, 128)
|
aq, aq_scale = per_token_group_quant_fp8(a, 128)
|
||||||
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
|
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
|
||||||
@@ -310,6 +331,7 @@ def _deep_ep_moe(
|
|||||||
w1_scale: Optional[torch.Tensor],
|
w1_scale: Optional[torch.Tensor],
|
||||||
w2_scale: Optional[torch.Tensor],
|
w2_scale: Optional[torch.Tensor],
|
||||||
use_fp8_dispatch: bool,
|
use_fp8_dispatch: bool,
|
||||||
|
per_act_token_quant: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not low_latency_mode:
|
if not low_latency_mode:
|
||||||
@@ -331,7 +353,8 @@ def _deep_ep_moe(
|
|||||||
with set_current_vllm_config(VllmConfig()):
|
with set_current_vllm_config(VllmConfig()):
|
||||||
# Reference
|
# Reference
|
||||||
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
|
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
|
||||||
w2_scale, use_fp8_dispatch)
|
w2_scale, use_fp8_dispatch,
|
||||||
|
per_act_token_quant)
|
||||||
|
|
||||||
# Splice experts for this rank.
|
# Splice experts for this rank.
|
||||||
num_local_experts = config.num_experts // pgi.world_size
|
num_local_experts = config.num_experts // pgi.world_size
|
||||||
@@ -356,6 +379,7 @@ def _deep_ep_moe(
|
|||||||
w2_scale_ep,
|
w2_scale_ep,
|
||||||
config.num_experts,
|
config.num_experts,
|
||||||
use_fp8_dispatch,
|
use_fp8_dispatch,
|
||||||
|
per_act_token_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
@@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
|||||||
@pytest.mark.parametrize("num_experts", [32])
|
@pytest.mark.parametrize("num_experts", [32])
|
||||||
@pytest.mark.parametrize("topk", [6])
|
@pytest.mark.parametrize("topk", [6])
|
||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
|
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
def test_deep_ep_moe(
|
||||||
num_experts: int, topk: int, world_dp_size: tuple[int,
|
dtype: torch.dtype,
|
||||||
int]):
|
mnk: tuple[int, int, int],
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
world_dp_size: tuple[int, int],
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
):
|
||||||
low_latency_mode = False
|
low_latency_mode = False
|
||||||
use_fp8_dispatch = False
|
use_fp8_dispatch = False
|
||||||
m, n, k = mnk
|
m, n, k = mnk
|
||||||
@@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
|||||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||||
|
|
||||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||||
|
per_act_token_quant)
|
||||||
|
|
||||||
|
|
||||||
MNKs = [
|
MNKs = [
|
||||||
@@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
|||||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||||
|
|
||||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||||
|
False)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.distributed.parallel_state import init_distributed_environment
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
@@ -142,6 +143,10 @@ def test_fused_moe(
|
|||||||
# Setup test data
|
# Setup test data
|
||||||
#
|
#
|
||||||
|
|
||||||
|
#
|
||||||
|
# Setup test data
|
||||||
|
#
|
||||||
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||||
@@ -169,7 +174,7 @@ def test_fused_moe(
|
|||||||
use_int8_w8a8=False,
|
use_int8_w8a8=False,
|
||||||
use_int8_w8a16=False,
|
use_int8_w8a16=False,
|
||||||
use_int4_w4a16=False,
|
use_int4_w4a16=False,
|
||||||
per_channel_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=None)
|
block_shape=None)
|
||||||
|
|
||||||
def m_fused_moe(
|
def m_fused_moe(
|
||||||
@@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
|||||||
if dtype == torch.float32:
|
if dtype == torch.float32:
|
||||||
pytest.skip("AITER ROCm test skip for float32")
|
pytest.skip("AITER ROCm test skip for float32")
|
||||||
|
|
||||||
|
monkeypatch.setenv('RANK', "0")
|
||||||
|
monkeypatch.setenv('LOCAL_RANK', "0")
|
||||||
|
monkeypatch.setenv('WORLD_SIZE', "1")
|
||||||
|
monkeypatch.setenv('MASTER_ADDR', 'localhost')
|
||||||
|
monkeypatch.setenv('MASTER_PORT', '12345')
|
||||||
|
init_distributed_environment()
|
||||||
|
|
||||||
# Instantiate our and huggingface's MoE blocks
|
# Instantiate our and huggingface's MoE blocks
|
||||||
vllm_config.compilation_config.static_forward_context = dict()
|
vllm_config.compilation_config.static_forward_context = dict()
|
||||||
with (set_current_vllm_config(vllm_config),
|
with (set_current_vllm_config(vllm_config),
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if not current_platform.has_device_capability(100):
|
if not current_platform.has_device_capability(100):
|
||||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
|
||||||
allow_module_level=True)
|
allow_module_level=True)
|
||||||
|
|
||||||
MNK_FACTORS = [
|
MNK_FACTORS = [
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .utils import ProcessGroupInfo, parallel_launch
|
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pplx_kernels import AllToAll
|
from pplx_kernels import AllToAll
|
||||||
@@ -93,7 +93,7 @@ def pplx_cutlass_moe(
|
|||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
experts_per_token=topk,
|
experts_per_token=topk,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=pgi.world_size,
|
world_size=world_size,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
|
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
|
||||||
@@ -118,8 +118,6 @@ def pplx_cutlass_moe(
|
|||||||
pgi.world_size,
|
pgi.world_size,
|
||||||
rank,
|
rank,
|
||||||
dp_size,
|
dp_size,
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
|
||||||
per_act_token=per_act_token,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
|
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
|
||||||
|
|||||||
@@ -18,18 +18,20 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
has_pplx = False
|
has_pplx = False
|
||||||
|
|
||||||
|
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
|
||||||
from tests.kernels.utils import torch_experts
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe import override_config
|
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
|
||||||
get_default_config)
|
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
from .utils import ProcessGroupInfo, parallel_launch
|
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||||
|
|
||||||
requires_pplx = pytest.mark.skipif(
|
requires_pplx = pytest.mark.skipif(
|
||||||
not has_pplx,
|
not has_pplx,
|
||||||
@@ -144,25 +146,6 @@ def torch_batched_moe(
|
|||||||
return torch_finalize(out, topk_weight, topk_ids)
|
return torch_finalize(out, topk_weight, topk_ids)
|
||||||
|
|
||||||
|
|
||||||
def batched_moe(
|
|
||||||
a: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weight: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
num_experts = w1.shape[0]
|
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
|
||||||
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
|
|
||||||
world_size=1,
|
|
||||||
dp_size=1,
|
|
||||||
rank=0),
|
|
||||||
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
|
|
||||||
|
|
||||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||||
@@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
|
|||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||||
|
|
||||||
torch.testing.assert_close(baseline_output,
|
torch.testing.assert_close(baseline_output,
|
||||||
torch_output,
|
torch_output,
|
||||||
@@ -226,7 +209,6 @@ def pplx_prepare_finalize(
|
|||||||
|
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
num_tokens, hidden_dim = a.shape
|
num_tokens, hidden_dim = a.shape
|
||||||
block_size = 128
|
|
||||||
device = pgi.device
|
device = pgi.device
|
||||||
rank = pgi.rank
|
rank = pgi.rank
|
||||||
world_size = pgi.world_size
|
world_size = pgi.world_size
|
||||||
@@ -241,9 +223,7 @@ def pplx_prepare_finalize(
|
|||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
hidden_dim_scale_bytes=0,
|
||||||
((hidden_dim + block_size - 1) // block_size *
|
|
||||||
torch.float32.itemsize)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
@@ -260,7 +240,6 @@ def pplx_prepare_finalize(
|
|||||||
world_size,
|
world_size,
|
||||||
rank,
|
rank,
|
||||||
dp_size,
|
dp_size,
|
||||||
a.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||||
@@ -276,6 +255,7 @@ def pplx_prepare_finalize(
|
|||||||
num_experts,
|
num_experts,
|
||||||
None,
|
None,
|
||||||
False,
|
False,
|
||||||
|
FusedMoEQuantConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
b_a = b_a * 1.5
|
b_a = b_a * 1.5
|
||||||
@@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
|
|||||||
# TODO (bnell): this test point does not work for odd M due to how the test is
|
# TODO (bnell): this test point does not work for odd M due to how the test is
|
||||||
# written, not due to limitations of the pplx kernels. The pplx_moe
|
# written, not due to limitations of the pplx kernels. The pplx_moe
|
||||||
# test below is able to deal with odd M.
|
# test below is able to deal with odd M.
|
||||||
|
# TODO (bnell) add fp8 tests
|
||||||
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
|
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@@ -386,18 +367,31 @@ def pplx_moe(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weight: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
qtype: Optional[torch.dtype] = None,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
use_compile: bool = False,
|
use_compile: bool = False,
|
||||||
use_cudagraphs: bool = True,
|
use_cudagraphs: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||||
PplxPrepareAndFinalize)
|
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
|
||||||
|
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
hidden_dim = a.shape[1]
|
hidden_dim = a.shape[1]
|
||||||
num_experts = w1.shape[0]
|
num_experts = w1.shape[0]
|
||||||
block_size = 128
|
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
|
||||||
|
|
||||||
|
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||||
|
max_num_tokens,
|
||||||
|
hidden_dim,
|
||||||
|
a.dtype,
|
||||||
|
qtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
@@ -407,10 +401,8 @@ def pplx_moe(
|
|||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
hidden_dim_bytes=hidden_dim_bytes,
|
||||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
hidden_dim_scale_bytes=scale_bytes,
|
||||||
((hidden_dim + block_size - 1) // block_size *
|
|
||||||
torch.float32.itemsize)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
@@ -429,9 +421,11 @@ def pplx_moe(
|
|||||||
dp_size,
|
dp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
|
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
dp_size=dp_size)
|
dp_size=dp_size,
|
||||||
|
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
|
||||||
|
block_shape=block_shape)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
@@ -447,6 +441,13 @@ def pplx_moe(
|
|||||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||||
|
|
||||||
|
if w1_scale is not None:
|
||||||
|
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
|
||||||
|
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
|
||||||
|
else:
|
||||||
|
w1_scale_chunk = None
|
||||||
|
w2_scale_chunk = None
|
||||||
|
|
||||||
# Note: for now use_compile will error out if the problem size is
|
# Note: for now use_compile will error out if the problem size is
|
||||||
# large enough to trigger chunking. I'm leaving the flag and
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
# setup code in case we are able to revisit this later.
|
# setup code in case we are able to revisit this later.
|
||||||
@@ -465,6 +466,8 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
|
w1_scale=w1_scale_chunk,
|
||||||
|
w2_scale=w2_scale_chunk,
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
if use_cudagraphs:
|
if use_cudagraphs:
|
||||||
@@ -477,6 +480,8 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
|
w1_scale=w1_scale_chunk,
|
||||||
|
w2_scale=w2_scale_chunk,
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
experts = BatchedExperts(max_num_tokens=a.shape[0],
|
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
|
||||||
world_size=1,
|
world_size=1,
|
||||||
dp_size=1)
|
dp_size=1)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
@@ -539,7 +544,12 @@ def _pplx_moe(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
score: torch.Tensor,
|
score: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
use_internode: bool,
|
w1_s: Optional[torch.Tensor] = None,
|
||||||
|
w2_s: Optional[torch.Tensor] = None,
|
||||||
|
qtype: Optional[torch.dtype] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
use_internode: bool = False,
|
||||||
):
|
):
|
||||||
if use_internode:
|
if use_internode:
|
||||||
uid = nvshmem_get_unique_id(
|
uid = nvshmem_get_unique_id(
|
||||||
@@ -557,11 +567,28 @@ def _pplx_moe(
|
|||||||
|
|
||||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||||
|
|
||||||
|
device = torch.device("cuda", pgi.rank)
|
||||||
|
a = a.to(device)
|
||||||
|
w1 = w1.to(device)
|
||||||
|
w2 = w2.to(device)
|
||||||
|
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||||
|
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
torch_output = torch_experts(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale=w1_s,
|
||||||
|
w2_scale=w2_s,
|
||||||
|
quant_dtype=qtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape)
|
||||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||||
a, w1, w2, topk_weight, topk_ids)
|
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
|
||||||
|
qtype, per_act_token_quant, block_shape)
|
||||||
# TODO (bnell): fix + re-enable
|
# TODO (bnell): fix + re-enable
|
||||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||||
# topk_ids)
|
# topk_ids)
|
||||||
@@ -581,6 +608,8 @@ def _pplx_moe(
|
|||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||||
|
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||||
|
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||||
@pytest.mark.parametrize("use_internode", [False])
|
@pytest.mark.parametrize("use_internode", [False])
|
||||||
@requires_pplx
|
@requires_pplx
|
||||||
def test_pplx_moe(
|
def test_pplx_moe(
|
||||||
@@ -589,15 +618,33 @@ def test_pplx_moe(
|
|||||||
topk: int,
|
topk: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
world_dp_size: tuple[int, int],
|
world_dp_size: tuple[int, int],
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
use_internode: bool,
|
use_internode: bool,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
m, n, k = mnk
|
m, n, k = mnk
|
||||||
world_size, dp_size = world_dp_size
|
world_size, dp_size = world_dp_size
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
if dtype == torch.float8_e4m3fn:
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
use_fp8_w8a8 = True
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
quant_dtype = dtype
|
||||||
|
else:
|
||||||
|
use_fp8_w8a8 = False
|
||||||
|
quant_dtype = None
|
||||||
|
|
||||||
|
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
|
||||||
|
pytest.skip("Skip quantization test for non-quantized type")
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||||
|
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
block_shape=block_shape)
|
||||||
|
|
||||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
||||||
|
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||||
use_internode)
|
use_internode)
|
||||||
|
|||||||
@@ -1,194 +1,249 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
DeepEP test utilities
|
from typing import Optional
|
||||||
"""
|
|
||||||
import dataclasses
|
|
||||||
import importlib
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
from torch.multiprocessing import (
|
|
||||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
|
||||||
|
|
||||||
from vllm.utils import get_open_port
|
import vllm._custom_ops as ops
|
||||||
|
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
|
||||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
per_block_cast_to_int8)
|
||||||
if has_deep_ep:
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
DeepEPHTPrepareAndFinalize)
|
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
DeepEPLLPrepareAndFinalize)
|
FusedMoEModularKernel)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
## Parallel Processes Utils
|
moe_kernel_quantize_input)
|
||||||
|
from vllm.utils import round_up
|
||||||
P = ParamSpec("P")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
def triton_moe(
|
||||||
class ProcessGroupInfo:
|
a: torch.Tensor,
|
||||||
world_size: int
|
w1: torch.Tensor,
|
||||||
world_local_size: int
|
w2: torch.Tensor,
|
||||||
rank: int
|
topk_weight: torch.Tensor,
|
||||||
node_rank: int
|
topk_ids: torch.Tensor,
|
||||||
local_rank: int
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
device: torch.device
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return fused_experts(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_channel_quant=per_act_token_quant,
|
||||||
|
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||||
|
block_shape=block_shape)
|
||||||
|
|
||||||
|
|
||||||
def _worker_parallel_launch(
|
def batched_moe(
|
||||||
local_rank: int,
|
a: torch.Tensor,
|
||||||
world_size: int,
|
w1: torch.Tensor,
|
||||||
world_local_size: int,
|
w2: torch.Tensor,
|
||||||
node_rank: int,
|
topk_weight: torch.Tensor,
|
||||||
init_method: str,
|
topk_ids: torch.Tensor,
|
||||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
*args: P.args,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
**kwargs: P.kwargs,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
rank = node_rank * world_local_size + local_rank
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
torch.cuda.set_device(local_rank)
|
per_act_token_quant: bool = False,
|
||||||
device = torch.device("cuda", local_rank)
|
block_shape: Optional[list[int]] = None,
|
||||||
torch.distributed.init_process_group(
|
) -> torch.Tensor:
|
||||||
backend="cpu:gloo,cuda:nccl",
|
max_num_tokens = round_up(a.shape[0], 64)
|
||||||
init_method=init_method,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
device_id=device,
|
|
||||||
)
|
|
||||||
barrier = torch.tensor([rank], device=device)
|
|
||||||
torch.distributed.all_reduce(barrier)
|
|
||||||
|
|
||||||
try:
|
fused_experts = FusedMoEModularKernel(
|
||||||
worker(
|
BatchedPrepareAndFinalize(max_num_tokens,
|
||||||
ProcessGroupInfo(
|
world_size=1,
|
||||||
world_size=world_size,
|
dp_size=1,
|
||||||
world_local_size=world_local_size,
|
rank=0),
|
||||||
rank=rank,
|
BatchedTritonExperts(
|
||||||
node_rank=node_rank,
|
max_num_tokens=max_num_tokens,
|
||||||
local_rank=local_rank,
|
world_size=1,
|
||||||
device=device,
|
dp_size=1,
|
||||||
),
|
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||||
*args,
|
per_act_token_quant=per_act_token_quant,
|
||||||
**kwargs,
|
block_shape=block_shape,
|
||||||
)
|
),
|
||||||
except Exception as ex:
|
|
||||||
print(ex)
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
torch.distributed.destroy_process_group()
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_launch(
|
|
||||||
world_size: int,
|
|
||||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> None:
|
|
||||||
assert not kwargs
|
|
||||||
spawn(
|
|
||||||
_worker_parallel_launch,
|
|
||||||
args=(
|
|
||||||
world_size,
|
|
||||||
world_size,
|
|
||||||
0,
|
|
||||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
|
||||||
worker,
|
|
||||||
) + args,
|
|
||||||
nprocs=world_size,
|
|
||||||
join=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return fused_experts(a,
|
||||||
## DeepEP specific utils
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
def naive_batched_moe(
|
||||||
class DeepEPHTArgs:
|
a: torch.Tensor,
|
||||||
num_local_experts: int
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
max_num_tokens = round_up(a.shape[0], 64)
|
||||||
|
|
||||||
|
fused_experts = FusedMoEModularKernel(
|
||||||
@dataclasses.dataclass
|
BatchedPrepareAndFinalize(max_num_tokens,
|
||||||
class DeepEPLLArgs:
|
world_size=1,
|
||||||
max_tokens_per_rank: int
|
dp_size=1,
|
||||||
hidden_size: int
|
rank=0),
|
||||||
num_experts: int
|
NaiveBatchedExperts(
|
||||||
use_fp8_dispatch: bool
|
max_num_tokens=max_num_tokens,
|
||||||
|
dp_size=1,
|
||||||
|
world_size=1,
|
||||||
def make_deepep_ht_a2a(pg: ProcessGroup,
|
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||||
pgi: ProcessGroupInfo,
|
per_act_token_quant=per_act_token_quant,
|
||||||
dp_size: int,
|
block_shape=block_shape,
|
||||||
ht_args: DeepEPHTArgs,
|
),
|
||||||
q_dtype: Optional[torch.dtype] = None,
|
|
||||||
block_shape: Optional[list[int]] = None):
|
|
||||||
|
|
||||||
import deep_ep
|
|
||||||
|
|
||||||
# high throughput a2a
|
|
||||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
|
||||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
|
||||||
buffer = deep_ep.Buffer(group=pg,
|
|
||||||
num_nvl_bytes=num_nvl_bytes,
|
|
||||||
num_rdma_bytes=num_rdma_bytes,
|
|
||||||
low_latency_mode=low_latency_mode,
|
|
||||||
num_qps_per_rank=num_qps_per_rank)
|
|
||||||
return DeepEPHTPrepareAndFinalize(buffer=buffer,
|
|
||||||
world_size=pgi.world_size,
|
|
||||||
rank=pgi.rank,
|
|
||||||
dp_size=dp_size,
|
|
||||||
rank_expert_offset=pgi.rank *
|
|
||||||
ht_args.num_local_experts,
|
|
||||||
quant_dtype=q_dtype,
|
|
||||||
block_shape=block_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def make_deepep_ll_a2a(pg: ProcessGroup,
|
|
||||||
pgi: ProcessGroupInfo,
|
|
||||||
dp_size: int,
|
|
||||||
deepep_ll_args: DeepEPLLArgs,
|
|
||||||
q_dtype: Optional[torch.dtype] = None,
|
|
||||||
block_shape: Optional[list[int]] = None):
|
|
||||||
|
|
||||||
import deep_ep
|
|
||||||
|
|
||||||
# low-latency a2a
|
|
||||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
|
||||||
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
|
|
||||||
pgi.world_size, deepep_ll_args.num_experts)
|
|
||||||
|
|
||||||
buffer = deep_ep.Buffer(group=pg,
|
|
||||||
num_rdma_bytes=num_rdma_bytes,
|
|
||||||
low_latency_mode=True,
|
|
||||||
num_qps_per_rank=deepep_ll_args.num_experts //
|
|
||||||
pgi.world_size)
|
|
||||||
|
|
||||||
return DeepEPLLPrepareAndFinalize(
|
|
||||||
buffer=buffer,
|
|
||||||
world_size=pgi.world_size,
|
|
||||||
dp_size=dp_size,
|
|
||||||
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
|
|
||||||
quant_dtype=q_dtype,
|
|
||||||
block_shape=block_shape,
|
|
||||||
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return fused_experts(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale)
|
||||||
|
|
||||||
def make_deepep_a2a(pg: ProcessGroup,
|
|
||||||
pgi: ProcessGroupInfo,
|
|
||||||
dp_size: int,
|
|
||||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
|
||||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
|
||||||
q_dtype: Optional[torch.dtype] = None,
|
|
||||||
block_shape: Optional[list[int]] = None):
|
|
||||||
if deepep_ht_args is not None:
|
|
||||||
assert deepep_ll_args is None
|
|
||||||
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
|
|
||||||
block_shape)
|
|
||||||
|
|
||||||
assert deepep_ll_args is not None
|
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||||
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
|
end: int) -> Optional[torch.Tensor]:
|
||||||
block_shape)
|
if scales is not None:
|
||||||
|
if scales.numel() == 1:
|
||||||
|
return scales
|
||||||
|
else:
|
||||||
|
return scales[start:end]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def make_quantized_test_activations(
|
||||||
|
E: int,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
in_dtype: torch.dtype,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
|
||||||
|
a_q = a
|
||||||
|
a_scale = None
|
||||||
|
|
||||||
|
if quant_dtype is not None:
|
||||||
|
assert (quant_dtype == torch.float8_e4m3fn
|
||||||
|
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
||||||
|
a_q = torch.zeros_like(a, dtype=quant_dtype)
|
||||||
|
a_scale_l = [None] * E
|
||||||
|
for e in range(E):
|
||||||
|
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
|
||||||
|
a[e], None, quant_dtype, per_act_token_quant, block_shape)
|
||||||
|
a_scale = torch.stack(a_scale_l)
|
||||||
|
|
||||||
|
if not per_act_token_quant and block_shape is None:
|
||||||
|
a_scale = a_scale.view(E, 1, 1)
|
||||||
|
|
||||||
|
return a, a_q, a_scale
|
||||||
|
|
||||||
|
|
||||||
|
def moe_quantize_weights(
|
||||||
|
w: torch.Tensor,
|
||||||
|
w_s: Optional[torch.Tensor],
|
||||||
|
quant_dtype: Optional[torch.dtype],
|
||||||
|
per_token_quant: bool,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
assert (quant_dtype == torch.float8_e4m3fn
|
||||||
|
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
||||||
|
|
||||||
|
if block_shape is not None:
|
||||||
|
assert not per_token_quant
|
||||||
|
if quant_dtype == torch.int8:
|
||||||
|
w, w_s = per_block_cast_to_int8(w, block_shape)
|
||||||
|
else:
|
||||||
|
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
||||||
|
else:
|
||||||
|
if quant_dtype == torch.int8:
|
||||||
|
w, w_s = ops.scaled_int8_quant(
|
||||||
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||||
|
else:
|
||||||
|
w, w_s = ops.scaled_fp8_quant(
|
||||||
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||||
|
|
||||||
|
return w, w_s
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create a random `(e, rows, cols)` expert weight on CUDA.

    Returns `(w_16, w, w_s)`: the unquantized weight, the weight to feed the
    kernel, and its scales. When `quant_dtype` is None, `w` is `w_16` itself
    and `w_s` is None. Otherwise each expert slice is quantized separately
    with `moe_quantize_weights` and the results are stacked.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None

    quantized = []
    scales = []
    for expert_w in w_16.unbind(0):
        w_q, w_scale = moe_quantize_weights(expert_w, None, quant_dtype,
                                            per_act_token_quant, block_shape)
        quantized.append(w_q)
        scales.append(w_scale)

    w = torch.stack(quantized)
    w_s = torch.stack(scales)

    # Stacked scales of shape (e, 1) are reshaped to (e, 1, 1).
    if w_s.ndim == 2:
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        # Block-wise scales must have one entry per (row tile, col tile).
        block_n, block_k = block_shape
        expected_shape = (
            e,
            (rows + block_n - 1) // block_n,
            (cols + block_k - 1) // block_k,
        )
        assert w_s.shape == expected_shape

    return w_16, w, w_s
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
           torch.Tensor, Optional[torch.Tensor]]:
    """Create the two MoE test weights in one call.

    The first weight has shape `(e, 2 * n, k)` and the second `(e, k, n)`.
    Both are built by `make_test_weight` with the same dtype/quantization
    settings. The result is the first weight's `(w_16, w, w_s)` triple
    followed by the second weight's triple.
    """
    # The first weight is generated before the second one.
    first = make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
                             per_act_token_quant)
    second = make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                              per_act_token_quant)
    return (*first, *second)
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
group_broadcast)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
# Using the default value (240.0) from pytorch will cause accuracy
|
# Using the default value (240.0) from pytorch will cause accuracy
|
||||||
# issue on dynamic quantization models. Here use 224.0 for rocm.
|
# issue on dynamic quantization models. Here use 224.0 for rocm.
|
||||||
@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
|||||||
return ref_out, ref_scale.view((1, ))
|
return ref_out, ref_scale.view((1, ))
|
||||||
|
|
||||||
|
|
||||||
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
def native_w8a8_block_matmul(
|
||||||
As: torch.Tensor, Bs: torch.Tensor, block_size,
|
A: torch.Tensor,
|
||||||
output_dtype):
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
compute_type: torch.dtype = torch.float32,
|
||||||
|
) -> torch.Tensor:
|
||||||
"""This function performs matrix multiplication with block-wise
|
"""This function performs matrix multiplication with block-wise
|
||||||
quantization using native torch.
|
quantization using native torch.
|
||||||
It is agnostic to the input data type and can be used for both int8 and
|
It is agnostic to the input data type and can be used for both int8 and
|
||||||
@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
|||||||
`Bs` (float32).
|
`Bs` (float32).
|
||||||
The output is returned in the specified `output_dtype`.
|
The output is returned in the specified `output_dtype`.
|
||||||
"""
|
"""
|
||||||
A = A.to(torch.float32)
|
A = A.to(compute_type)
|
||||||
B = B.to(torch.float32)
|
B = B.to(compute_type)
|
||||||
assert A.shape[-1] == B.shape[-1]
|
assert A.shape[-1] == B.shape[-1]
|
||||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
assert len(block_size) == 2
|
assert len(block_size) == 2
|
||||||
@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
|||||||
As = As.reshape(M, As.shape[-1])
|
As = As.reshape(M, As.shape[-1])
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
k_tiles = (K + block_k - 1) // block_k
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
assert n_tiles == Bs.shape[0]
|
assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
|
||||||
assert k_tiles == Bs.shape[1]
|
assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
|
||||||
|
|
||||||
C_shape = (M, N)
|
C_shape = (M, N)
|
||||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
|
||||||
|
|
||||||
A_tiles = [
|
A_tiles = [
|
||||||
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
||||||
@@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
|||||||
|
|
||||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||||
return C
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def native_per_token_group_quant_fp8(x,
|
||||||
|
group_size,
|
||||||
|
eps=1e-10,
|
||||||
|
dtype=torch.float8_e4m3fn):
|
||||||
|
"""Function to perform per-token-group quantization on an input tensor
|
||||||
|
`x` using native torch."""
|
||||||
|
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
|
||||||
|
"be divisible by `group_size`")
|
||||||
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
fp8_min = finfo.min
|
||||||
|
fp8_max = finfo.max
|
||||||
|
|
||||||
|
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||||
|
amax = x_.abs().max(dim=-1,
|
||||||
|
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||||
|
x_s = amax / fp8_max
|
||||||
|
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||||
|
x_q = x_q.reshape(x.shape)
|
||||||
|
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||||
|
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Per-token-group int8 quantization of `x` implemented in native torch.

    The last dimension of `x` is split into groups of `group_size` elements.
    Each group gets the float32 scale `max(|group|, eps) / iinfo(dtype).max`.
    Values are divided by the scale in float32, rounded, clamped to the range
    of `dtype` and cast to it.

    Returns `(x_q, x_s)`: `x_q` has the shape of `x`, and `x_s` has the shape
    of `x` with the last dimension divided by `group_size`.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    limits = torch.iinfo(dtype)

    # One row per quantization group.
    groups = x.reshape(x.numel() // group_size, group_size)
    # The scale is computed in float32 for stability.
    group_amax = groups.abs().max(dim=-1, keepdim=True).values
    scales = group_amax.clamp(min=eps).to(torch.float32) / limits.max
    # Round first, then clamp into the integer range.
    rounded = (groups.to(torch.float32) / scales).round()
    quantized = rounded.clamp(min=limits.min, max=limits.max).to(dtype)

    scale_shape = x.shape[:-1] + (x.shape[-1] // group_size, )
    return quantized.reshape(x.shape), scales.reshape(scale_shape)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_BLOCK_SHAPE = [128, 128]
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_fp8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to fp8 (e4m3fn) with one scale per block.

    `x` is zero-padded up to multiples of `block_shape`, split into
    `block_shape[0] x block_shape[1]` tiles, and each tile is multiplied by
    `448 / amax(tile)` (amax clamped to at least 1e-4) before the fp8 cast.

    Returns `(x_q, scales)`: `x_q` has the shape of `x` (padding removed), and
    `scales` holds `amax / 448` per tile, shaped
    `(row tiles, column tiles)`.
    """
    block_m, block_n = block_shape
    assert x.dim() == 2
    rows, cols = x.shape

    padded_shape = (round_up(rows, block_m), round_up(cols, block_n))
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # (row tiles, block_m, column tiles, block_n)
    tiles = padded.view(-1, block_m, padded.size(1) // block_n, block_n)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    quantized = (tiles * (448.0 / tile_amax)).to(torch.float8_e4m3fn)
    scales = (tile_amax / 448.0).view(tiles.size(0), tiles.size(2))

    return quantized.view_as(padded)[:rows, :cols].contiguous(), scales
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_cast_to_int8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to int8 with one scale per block.

    `x` is zero-padded up to multiples of `block_shape`, split into
    `block_shape[0] x block_shape[1]` tiles, and each tile is scaled so that
    its largest magnitude (clamped to at least 1e-4) maps to the int8 maximum.
    Scaled values are rounded and clamped before the cast, so nothing leaves
    the representable int8 range.

    Returns `(x_q, scales)`: `x_q` has the shape of `x` (padding removed), and
    `scales` holds `amax / int8_max` per tile, shaped
    `(row tiles, column tiles)`, so that `x_q * scale` approximates `x`.
    """
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape

    int8_info = torch.iinfo(torch.int8)
    int8_max = float(int8_info.max)

    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    # (row tiles, block_m, column tiles, block_n)
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    # Scaling by 256 / amax pushed values up to +-256, outside the int8
    # range. Scale to the true int8 maximum, then round and clamp.
    x_scaled = (x_view * (int8_max / x_amax)).round().clamp(
        min=int8_info.min, max=int8_info.max).to(torch.int8)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / int8_max).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
|
||||||
|
|
||||||
|
|
||||||
|
def dequant(
    t: torch.Tensor,
    scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
    """Dequantize `t` with `scale` and cast the result to `out_dtype`.

    Without a scale, `t` is only cast. With per-token quantization or no
    `block_shape`, the scale is multiplied in directly. Otherwise the scale
    is first expanded to the shape of `t` with `group_broadcast`. The
    multiplication is done in float32.
    """
    if scale is None:
        return t.to(out_dtype)

    t_f32 = t.to(torch.float32)
    if per_act_token_quant or block_shape is None:
        scaled = t_f32 * scale
    else:
        scaled = t_f32 * group_broadcast(scale, t.shape)
    return scaled.to(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def native_batched_masked_quant_matmul(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
C: torch.Tensor,
|
||||||
|
num_expert_tokens: torch.Tensor,
|
||||||
|
A_scale: Optional[torch.Tensor] = None,
|
||||||
|
B_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
num_expert_tokens_cpu = num_expert_tokens.clone()
|
||||||
|
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
||||||
|
num_experts = num_expert_tokens.size(0)
|
||||||
|
|
||||||
|
for e in range(num_experts):
|
||||||
|
num_tokens = num_expert_tokens_cpu[e]
|
||||||
|
if A.dtype.itemsize == 1 and block_shape is not None:
|
||||||
|
assert A_scale is not None and B_scale is not None
|
||||||
|
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
|
||||||
|
block_shape, C.dtype)
|
||||||
|
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||||
|
elif A.dtype.itemsize == 1 and block_shape is None:
|
||||||
|
assert A_scale is not None and B_scale is not None
|
||||||
|
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
|
||||||
|
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
|
||||||
|
C[e, :num_tokens, :] = (
|
||||||
|
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
|
||||||
|
else:
|
||||||
|
assert A_scale is None
|
||||||
|
assert B_scale is None
|
||||||
|
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||||
|
|
||||||
|
return C
|
||||||
|
|||||||
@@ -7,16 +7,10 @@ import itertools
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
native_w8a8_block_matmul,
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
per_block_cast_to_fp8)
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
|
||||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|
||||||
fused_topk, modular_triton_fused_moe)
|
|
||||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
|
||||||
moe_align_block_size)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -46,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824]
|
|||||||
K = [256, 3884, 4096, 13824, 16384]
|
K = [256, 3884, 4096, 13824, 16384]
|
||||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||||
# and its hidden size is 7168.
|
# and its hidden size is 7168.
|
||||||
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
|
|
||||||
M_moe_dg = [128, 192, 1335, 2048]
|
|
||||||
N_moe = [128, 256, 1024, 4608] # [13824]
|
|
||||||
K_moe = [256, 512, 7168] # [13824]
|
|
||||||
BLOCK_SIZE = [[128, 128]]
|
BLOCK_SIZE = [[128, 128]]
|
||||||
E = [2, 8, 16, 24] # [128, 256]
|
|
||||||
TOP_KS = [1, 2, 6]
|
|
||||||
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def native_per_token_group_quant_fp8(x,
|
|
||||||
group_size,
|
|
||||||
eps=1e-10,
|
|
||||||
dtype=torch.float8_e4m3fn):
|
|
||||||
"""Function to perform per-token-group quantization on an input tensor
|
|
||||||
`x` using native torch."""
|
|
||||||
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
|
|
||||||
"be divisible by `group_size`")
|
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
fp8_min = finfo.min
|
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
|
||||||
amax = x_.abs().max(dim=-1,
|
|
||||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
|
||||||
x_s = amax / fp8_max
|
|
||||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
|
||||||
x_q = x_q.reshape(x.shape)
|
|
||||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
|
||||||
|
|
||||||
return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Fused moe with block-wise quantization using native torch.

    Each row of `a` is repeated `topk` times and routed to the experts chosen
    by `torch.topk` over the softmax of `score`. For every expert that
    received rows, the rows go through a block-quantized matmul with `w1[i]`,
    `SiluAndMul`, a re-quantization, and a block-quantized matmul with
    `w2[i]`. The per-row outputs are combined with the top-k weights and
    summed per input row.
    """
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Activations are quantized in groups of block_k along the last dim.
    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            # Only the quantized activation is consumed below; the former
            # float32 copy of `act_out` was never read and has been removed.
            act_out_q, act_out_s = native_per_token_group_quant_fp8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
# Skip all tests if CUDA is not available
|
# Skip all tests if CUDA is not available
|
||||||
pytest.importorskip("torch.cuda")
|
pytest.importorskip("torch.cuda")
|
||||||
|
|
||||||
@@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
    "M,N,K,E,topk,block_size,dtype,seed",
    itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
                      SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Compare `fused_moe` and the modular triton MoE with the torch reference.

    Random fp8 weights and random block scales are generated, both fused
    implementations are run, and each result must be within a mean relative
    difference of 0.03 of `torch_w8a8_block_fp8_moe`.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    scale_factor = 1e-2
    fp8 = torch.finfo(torch.float8_e4m3fn)

    def random_fp8_weight(shape):
        # Uniform values spanning the fp8 range, clamped, then cast to fp8.
        w_bf16 = (torch.rand(shape, dtype=torch.bfloat16) - 0.5) * 2 * fp8.max
        return w_bf16.clamp(min=fp8.min, max=fp8.max).to(torch.float8_e4m3fn)

    def num_tiles(size, block):
        return (size + block - 1) // block

    # NOTE: the creation order below fixes the random stream; keep it.
    a = torch.randn((M, K), dtype=dtype) / 10
    w1 = random_fp8_weight((E, 2 * N, K))
    w2 = random_fp8_weight((E, K, N))

    block_n, block_k = block_size[0], block_size[1]
    w1_scale_shape = (E, num_tiles(2 * N, block_n), num_tiles(K, block_k))
    w2_scale_shape = (E, num_tiles(K, block_n), num_tiles(N, block_k))
    w1_s = torch.rand(w1_scale_shape, dtype=torch.float32) * scale_factor
    w2_s = torch.rand(w2_scale_shape, dtype=torch.float32) * scale_factor

    score = torch.randn((M, E), dtype=dtype)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_channel_quant=False,
                                           block_shape=block_size)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_fp8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                           block_size)

        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        m_out = m_fused_moe(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            global_num_experts=E,
                            w1_scale=w1_s,
                            w2_scale=w2_s)

    # Mean relative difference of each implementation against the reference.
    ref_f32 = ref_out.to(torch.float32)
    ref_magnitude = torch.mean(torch.abs(ref_f32))
    for candidate in (out, m_out):
        abs_err = torch.abs(candidate.to(torch.float32) - ref_f32)
        rel_diff = torch.mean(abs_err) / ref_magnitude
        assert rel_diff < 0.03
|
|
||||||
|
|
||||||
|
|
||||||
def per_block_cast_to_fp8(
|
|
||||||
x: torch.Tensor,
|
|
||||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.dim() == 2
|
|
||||||
m, n = x.shape
|
|
||||||
x_padded = torch.zeros(
|
|
||||||
(deep_gemm.ceil_div(m, 128) * 128,
|
|
||||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
|
||||||
dtype=x.dtype,
|
|
||||||
device=x.device)
|
|
||||||
x_padded[:m, :n] = x
|
|
||||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
|
||||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
|
||||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
|
||||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
|
||||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
|
||||||
return x_scaled_sub, scales
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"M,N,K,block_size,out_dtype,seed",
|
"M,N,K,block_size,out_dtype,seed",
|
||||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||||
@@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
def fp8_perm(m, idx):
|
|
||||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
|
||||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
|
||||||
else:
|
|
||||||
return m[idx, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
    """Reorder activations (and their scales) into expert-sorted order.

    `moe_align_block_size` provides the sorted token ids and one expert id
    per block of `block_m` rows. Sorted ids are clamped to the last valid
    token, so ids beyond `topk * M - 1` gather a real row.

    Returns `(a, a_s, m_indices, inv_perm)`: the permuted activations, the
    permuted scales (or None), the expert id repeated `block_m` times per
    block, and the first `M * topk` entries of the argsort of the sorted ids.
    """
    M, _ = a.shape

    sorted_token_ids, m_indices, _ = moe_align_block_size(topk_ids,
                                                          block_m,
                                                          num_groups,
                                                          None,
                                                          pad_sorted_ids=True)

    last_token = topk * M - 1
    sorted_token_ids = sorted_token_ids.clamp(max=last_token)
    m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:M * topk]

    # Each input row has `topk` consecutive token ids.
    source_rows = sorted_token_ids // topk
    a = fp8_perm(a, source_rows)
    if a_s is not None:
        a_s = a_s[source_rows]

    return a, a_s, m_indices, inv_perm
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
|
|
||||||
M = topk_weight.shape[0]
|
|
||||||
out = out[inv_perm, ...]
|
|
||||||
tmp_out = out.view(-1, topk, K)
|
|
||||||
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
                                 block_shape):
    """Fused moe with block-wise quantization using DeepGemm grouped gemm.

    The `M` and `K` arguments are ignored; both are re-read from `a.shape`.
    Activations are quantized per token group, permuted into expert order,
    run through two DeepGemm grouped GEMMs with `SiluAndMul` and a
    re-quantization in between, and finally un-permuted and combined with the
    top-k weights.
    """
    num_groups = w1.shape[0]
    M, K = a.shape
    N = w2.shape[-1]

    topk_weight, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()

    block_k = block_shape[1]

    # Quantization groups run along the hidden (K) dimension, so the group
    # size is block_k, as in the second quantization below. The previous
    # use of block_m only worked because callers pass equal block sizes.
    a_q, a_s = per_token_group_quant_fp8(a, block_k)

    a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
                                                 num_groups, topk, block_m)

    inter_out = torch.zeros((a_q.shape[0], N * 2),
                            dtype=torch.bfloat16,
                            device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
                                                        inter_out, m_indices)

    act_out = SiluAndMul().forward_native(inter_out)
    act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)

    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)

    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (act_out_q, act_out_s), (w2, w2_s), out, m_indices)

    final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)

    return final_out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
    "M,N,K,E,topk,seed",
    itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    """Compare `deep_gemm_moe_fp8` with a reference implementation.

    Random bf16 weights are block-quantized to fp8 per expert, the kernel is
    run (and replayed from a CUDA graph for large problems), and the result
    must be within a mean relative difference of 0.03 of the reference.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    fp8 = torch.finfo(torch.float8_e4m3fn)

    def random_bf16_weight(shape):
        # Uniform values spanning the fp8 range, clamped, kept in bf16.
        raw = (torch.rand(shape, dtype=torch.bfloat16) - 0.5) * 2 * fp8.max
        return raw.clamp(min=fp8.min, max=fp8.max)

    def num_tiles(size, block):
        return (size + block - 1) // block

    # NOTE: the creation order below fixes the random stream; keep it.
    a = torch.randn((M, K), dtype=dtype) / 10
    w1_bf16 = random_bf16_weight((E, 2 * N, K))
    w2_bf16 = random_bf16_weight((E, K, N))
    score = torch.randn((M, E), dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    w1_scale_shape = (E, num_tiles(2 * N, block_n), num_tiles(K, block_k))
    w2_scale_shape = (E, num_tiles(K, block_n), num_tiles(N, block_k))

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)

    w1_s = torch.empty(w1_scale_shape, dtype=torch.float32)
    w2_s = torch.empty(w2_scale_shape, dtype=torch.float32)

    w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
    w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()

    assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    # Quantize every expert's weights and fill in the matching scales.
    for expert in range(E):
        w1[expert], w1_s[expert] = per_block_cast_to_fp8(w1_bf16[expert])
        w2[expert], w2_s[expert] = per_block_cast_to_fp8(w2_bf16[expert])

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        if M >= 128:
            ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
                                                   score, topk, block_size)
        else:
            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
                                               topk, block_size)

        topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                   topk_ids)

        if use_cudagraph:
            # Capture the kernel in a CUDA graph and replay it once.
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    # Mean relative difference against the reference output.
    ref_f32 = ref_out.to(torch.float32)
    abs_err = torch.abs(out.to(torch.float32) - ref_f32)
    rel_diff = torch.mean(abs_err) / torch.mean(torch.abs(ref_f32))

    assert rel_diff < 0.03
|
|
||||||
|
|||||||
@@ -8,9 +8,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
|
||||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||||
w8a8_block_int8_matmul)
|
w8a8_block_int8_matmul)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -23,82 +21,10 @@ vllm_config = VllmConfig()
|
|||||||
vllm_config.scheduler_config.max_num_seqs = 128
|
vllm_config.scheduler_config.max_num_seqs = 128
|
||||||
vllm_config.scheduler_config.max_model_len = 8192
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
|
||||||
# For test
|
|
||||||
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    The last dimension of `x` is split into groups of `group_size` elements;
    each group uses the float32 scale `max(|group|, eps) / iinfo(dtype).max`.
    The returned scales have the shape of `x` with the last dimension divided
    by `group_size`.
    """
    # The message previously said "cannot be divisible", which is the
    # opposite of what this check requires.
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    # One row per quantization group.
    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s
|
|
||||||
|
|
||||||
|
|
||||||
# For test
|
|
||||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
|
||||||
"""This function performs fused moe with block-wise quantization using
|
|
||||||
native torch."""
|
|
||||||
B, D = a.shape
|
|
||||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
|
||||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
||||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
|
||||||
topk_weight, topk_ids = torch.topk(score, topk)
|
|
||||||
topk_weight = topk_weight.view(-1)
|
|
||||||
topk_ids = topk_ids.view(-1)
|
|
||||||
|
|
||||||
_, block_k = block_shape[0], block_shape[1]
|
|
||||||
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
|
|
||||||
for i in range(w1.shape[0]):
|
|
||||||
mask = topk_ids == i
|
|
||||||
if mask.sum():
|
|
||||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
|
||||||
w1[i],
|
|
||||||
a_s[mask],
|
|
||||||
w1_s[i],
|
|
||||||
block_shape,
|
|
||||||
output_dtype=a.dtype)
|
|
||||||
act_out = SiluAndMul().forward_native(inter_out)
|
|
||||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
|
||||||
act_out, block_k)
|
|
||||||
act_out = act_out.to(torch.float32)
|
|
||||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
|
||||||
w2[i],
|
|
||||||
act_out_s,
|
|
||||||
w2_s[i],
|
|
||||||
block_shape,
|
|
||||||
output_dtype=a.dtype)
|
|
||||||
return (out.view(B, -1, w2.shape[1]) *
|
|
||||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
DTYPES = [torch.half, torch.bfloat16]
|
DTYPES = [torch.half, torch.bfloat16]
|
||||||
M = [1, 33, 64, 222]
|
M = [1, 33, 64, 222]
|
||||||
N = [128, 1024]
|
N = [128, 1024]
|
||||||
K = [256, 4096]
|
K = [256, 4096]
|
||||||
E = [8, 24]
|
|
||||||
TOP_KS = [2, 6]
|
|
||||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||||
BLOCK_SIZE = [[128, 128]]
|
BLOCK_SIZE = [[128, 128]]
|
||||||
SEEDS = [0]
|
SEEDS = [0]
|
||||||
@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"M, N, K, E, topk, block_size, dtype, seed",
|
|
||||||
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
|
||||||
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
|
|
||||||
native torch reference."""
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
# Use a smaller factor for scale initialization to prevent large
|
|
||||||
# values/overflow especially when output dtype might be float16
|
|
||||||
factor_for_scale = 1e-2
|
|
||||||
int8_info = torch.iinfo(torch.int8)
|
|
||||||
int8_max, int8_min = int8_info.max, int8_info.min
|
|
||||||
|
|
||||||
a = torch.randn((M, K), dtype=dtype) / 10
|
|
||||||
|
|
||||||
w1_fp32 = (torch.rand(
|
|
||||||
(E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
|
|
||||||
w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
|
||||||
|
|
||||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
|
|
||||||
w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
|
||||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
|
||||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
|
||||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
|
||||||
|
|
||||||
w1_s = (torch.rand(
|
|
||||||
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
|
|
||||||
w2_s = (torch.rand(
|
|
||||||
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
|
|
||||||
|
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
|
||||||
|
|
||||||
# Set the context to avoid lots of warning spam.
|
|
||||||
with set_current_vllm_config(vllm_config):
|
|
||||||
out = fused_moe(
|
|
||||||
a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
score,
|
|
||||||
topk,
|
|
||||||
renormalize=False,
|
|
||||||
use_int8_w8a8=True,
|
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
block_shape=block_size,
|
|
||||||
)
|
|
||||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
|
||||||
block_size)
|
|
||||||
|
|
||||||
# Check results
|
|
||||||
rel_diff = (torch.mean(
|
|
||||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
|
||||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
|
||||||
assert rel_diff < 0.06
|
|
||||||
|
|||||||
@@ -13,8 +13,11 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from torch._prims_common import TensorLikeType
|
from torch._prims_common import TensorLikeType
|
||||||
|
|
||||||
|
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
moe_kernel_quantize_input)
|
||||||
from vllm.platforms.interface import _Backend
|
from vllm.platforms.interface import _Backend
|
||||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||||
@@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref):
|
|||||||
torch.abs(output_ref))
|
torch.abs(output_ref))
|
||||||
|
|
||||||
|
|
||||||
def torch_experts(a: torch.Tensor,
|
def torch_experts(
|
||||||
w1: torch.Tensor,
|
a: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
topk_weight: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
global_num_experts: int = -1,
|
topk_ids: torch.Tensor,
|
||||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
assert (global_num_experts == -1
|
assert (global_num_experts == -1
|
||||||
or (global_num_experts == w1.shape[0] and expert_map is None)
|
or (global_num_experts == w1.shape[0] and expert_map is None)
|
||||||
or (expert_map is not None
|
or (expert_map is not None
|
||||||
and global_num_experts == expert_map.shape[0]))
|
and global_num_experts == expert_map.shape[0]))
|
||||||
|
|
||||||
|
M, K = a.shape
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
B, D = a.shape
|
|
||||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
||||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
||||||
topk_weight = topk_weight.view(-1)
|
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
|
|
||||||
|
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
|
||||||
|
per_act_token_quant, block_shape)
|
||||||
|
|
||||||
|
num_experts = w1.shape[0]
|
||||||
|
|
||||||
topk_ids = topk_ids.view(-1)
|
topk_ids = topk_ids.view(-1)
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
topk_ids = expert_map[topk_ids]
|
topk_ids = expert_map[topk_ids]
|
||||||
for i in range(w1.shape[0]):
|
|
||||||
|
for i in range(num_experts):
|
||||||
mask = topk_ids == i
|
mask = topk_ids == i
|
||||||
if mask.sum():
|
if mask.sum():
|
||||||
out[mask] = SiluAndMul()(
|
if quant_dtype is None:
|
||||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||||
return (out.view(B, -1, w2.shape[1]) *
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||||
|
elif block_shape is not None:
|
||||||
|
assert (a_scale is not None and w1_scale is not None
|
||||||
|
and w2_scale is not None)
|
||||||
|
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||||
|
w1_scale[i], block_shape,
|
||||||
|
out.dtype)
|
||||||
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
|
tmp2, b_scale = moe_kernel_quantize_input(
|
||||||
|
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
|
||||||
|
|
||||||
|
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||||
|
w2_scale[i], block_shape,
|
||||||
|
out.dtype)
|
||||||
|
else:
|
||||||
|
assert (a_scale is not None and w1_scale is not None
|
||||||
|
and w2_scale is not None)
|
||||||
|
f32 = torch.float32
|
||||||
|
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
|
||||||
|
tmp1 = a[mask].to(f32) * scales
|
||||||
|
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||||
|
tmp1 = tmp1 @ w1_dq
|
||||||
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
|
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||||
|
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||||
|
|
||||||
|
return (out.view(M, -1, w2.shape[1]) *
|
||||||
|
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
def torch_moe(a: torch.Tensor,
|
def torch_moe(a: torch.Tensor,
|
||||||
|
|||||||
@@ -1274,7 +1274,7 @@ def scaled_fp8_quant(
|
|||||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||||
else:
|
else:
|
||||||
assert scale.numel() == 1
|
assert scale.numel() == 1, f"{scale.shape}"
|
||||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||||
|
|
||||||
return output, scale
|
return output, scale
|
||||||
|
|||||||
@@ -4,8 +4,12 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||||
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
|
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||||
|
FusedMoEPrepareAndFinalize)
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
_config: Optional[dict[str, Any]] = None
|
_config: Optional[dict[str, Any]] = None
|
||||||
@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FusedMoE",
|
"FusedMoE",
|
||||||
|
"FusedMoEConfig",
|
||||||
"FusedMoEMethodBase",
|
"FusedMoEMethodBase",
|
||||||
"FusedMoeWeightScaleSupported",
|
"FusedMoeWeightScaleSupported",
|
||||||
|
"FusedMoEPermuteExpertsUnpermute",
|
||||||
|
"FusedMoEActivationFormat",
|
||||||
|
"FusedMoEPrepareAndFinalize",
|
||||||
"override_config",
|
"override_config",
|
||||||
"get_config",
|
"get_config",
|
||||||
]
|
]
|
||||||
@@ -36,11 +44,21 @@ if HAS_TRITON:
|
|||||||
# import to register the custom ops
|
# import to register the custom ops
|
||||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||||
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
|
BatchedDeepGemmExperts)
|
||||||
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||||
|
BatchedTritonOrDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
|
DeepGemmExperts)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
|
BatchedTritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||||
get_config_file_name, grouped_topk)
|
get_config_file_name, grouped_topk)
|
||||||
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||||
|
TritonOrDeepGemmExperts)
|
||||||
|
|
||||||
__all__ += [
|
__all__ += [
|
||||||
"fused_moe",
|
"fused_moe",
|
||||||
@@ -50,5 +68,11 @@ if HAS_TRITON:
|
|||||||
"grouped_topk",
|
"grouped_topk",
|
||||||
"cutlass_moe_fp8",
|
"cutlass_moe_fp8",
|
||||||
"cutlass_moe_fp4",
|
"cutlass_moe_fp4",
|
||||||
|
"CutlassExpertsFp8",
|
||||||
"TritonExperts",
|
"TritonExperts",
|
||||||
|
"BatchedTritonExperts",
|
||||||
|
"DeepGemmExperts",
|
||||||
|
"BatchedDeepGemmExperts",
|
||||||
|
"TritonOrDeepGemmExperts",
|
||||||
|
"BatchedTritonOrDeepGemmExperts",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
@@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm(
|
|||||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
# The Deep Gemm kernels only support block size of 128
|
# The Deep Gemm kernels only support block size of 128
|
||||||
DEEPGEMM_BLOCK_SHAPE = 128
|
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
||||||
|
|
||||||
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
|
def __init__(self,
|
||||||
block_shape: list[int]):
|
max_num_tokens: int,
|
||||||
|
world_size: int,
|
||||||
|
dp_size: int,
|
||||||
|
block_shape: list[int],
|
||||||
|
per_act_token_quant=False):
|
||||||
"""
|
"""
|
||||||
max_num_tokens: Maximum number of tokens from a DP Rank
|
max_num_tokens: Maximum number of tokens from a DP Rank
|
||||||
world_size: Number of EP ranks
|
world_size: Number of EP ranks
|
||||||
dp_size: Number of data-parallel ranks
|
dp_size: Number of data-parallel ranks
|
||||||
block_shape: Block quantization block shape
|
block_shape: Block quantization block shape
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
FusedMoEQuantConfig(
|
||||||
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
))
|
||||||
|
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.block_shape = block_shape
|
|
||||||
|
|
||||||
assert (len(self.block_shape) == 2 and all(
|
@property
|
||||||
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||||
|
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -248,6 +265,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
):
|
):
|
||||||
import deep_gemm as dg
|
import deep_gemm as dg
|
||||||
assert hidden_states.ndim == 3
|
assert hidden_states.ndim == 3
|
||||||
|
assert self.block_shape is not None
|
||||||
|
|
||||||
a1q = hidden_states
|
a1q = hidden_states
|
||||||
_, N, K = w1.size()
|
_, N, K = w1.size()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts)
|
BatchedTritonExperts)
|
||||||
|
|
||||||
@@ -20,43 +21,45 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
per_channel_quant: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
allow_deep_gemm: bool = False):
|
allow_deep_gemm: bool = False):
|
||||||
super().__init__()
|
|
||||||
assert not use_int8_w8a8, "NYI"
|
assert not use_int8_w8a8, "NYI"
|
||||||
assert not use_int8_w8a16, "NYI"
|
assert not use_int8_w8a16, "NYI"
|
||||||
assert not use_int4_w4a16, "NYI"
|
assert not use_int4_w4a16, "NYI"
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
FusedMoEQuantConfig.make(
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
block_shape=block_shape,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
))
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
|
||||||
self.use_int8_w8a8 = use_int8_w8a8
|
|
||||||
self.use_int8_w8a16 = use_int8_w8a16
|
|
||||||
self.use_int4_w4a16 = use_int4_w4a16
|
|
||||||
self.per_channel_quant = per_channel_quant
|
|
||||||
self.block_shape = block_shape
|
|
||||||
self.allow_deep_gemm = allow_deep_gemm
|
self.allow_deep_gemm = allow_deep_gemm
|
||||||
|
|
||||||
# BatchedTritonKernel doesn't support block quantization
|
# BatchedTritonKernel doesn't support block quantization
|
||||||
# at the moment.
|
# at the moment.
|
||||||
self.batched_triton_experts = BatchedTritonExperts(
|
self.batched_triton_experts = BatchedTritonExperts(
|
||||||
max_num_tokens=self.max_num_tokens,
|
max_num_tokens=self.max_num_tokens,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=self.use_int8_w8a8,
|
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
|
||||||
per_channel_quant=self.per_channel_quant,
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
world_size=self.world_size,
|
world_size=self.world_size,
|
||||||
dp_size=self.dp_size) if self.block_shape is None else None
|
dp_size=self.dp_size,
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
per_act_token_quant=self.per_act_token_quant,
|
||||||
|
block_shape=self.block_shape,
|
||||||
|
) if self.block_shape is None else None
|
||||||
|
|
||||||
|
is_fp8_128_block_quantized = (
|
||||||
|
use_fp8_w8a8 and self.block_shape
|
||||||
|
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
|
||||||
|
|
||||||
is_fp8_128_block_quantized = (self.use_fp8_w8a8
|
|
||||||
and self.block_shape is not None
|
|
||||||
and len(self.block_shape) == 2 and all(
|
|
||||||
[b == 128
|
|
||||||
for b in self.block_shape]))
|
|
||||||
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
||||||
max_num_tokens=self.max_num_tokens,
|
max_num_tokens=self.max_num_tokens,
|
||||||
world_size=self.world_size,
|
world_size=self.world_size,
|
||||||
@@ -67,12 +70,31 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
assert (self.batched_deep_gemm_experts is not None
|
assert (self.batched_deep_gemm_experts is not None
|
||||||
or self.batched_triton_experts is not None)
|
or self.batched_triton_experts is not None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
if self.batched_triton_experts is not None:
|
||||||
|
assert (self.batched_deep_gemm_experts is None
|
||||||
|
or self.batched_deep_gemm_experts.activation_formats
|
||||||
|
== self.batched_triton_experts.activation_formats)
|
||||||
|
return self.batched_triton_experts.activation_formats
|
||||||
|
else:
|
||||||
|
assert self.batched_deep_gemm_experts is not None
|
||||||
|
return self.batched_deep_gemm_experts.activation_formats
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
bdge = self.batched_deep_gemm_experts
|
bdge = self.batched_deep_gemm_experts
|
||||||
bte = self.batched_triton_experts
|
bte = self.batched_triton_experts
|
||||||
return ((bdge is None or bdge.supports_chunking())
|
return ((bdge is None or bdge.supports_chunking())
|
||||||
and (bte is None or bte.supports_chunking()))
|
and (bte is None or bte.supports_chunking()))
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
bdge = self.batched_deep_gemm_experts
|
||||||
|
bte = self.batched_triton_experts
|
||||||
|
return ((bdge is None or bdge.supports_expert_map())
|
||||||
|
and (bte is None or bte.supports_expert_map()))
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -87,7 +109,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||||
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
|
if self.allow_deep_gemm:
|
||||||
|
assert self.batched_deep_gemm_experts is not None
|
||||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||||
else:
|
else:
|
||||||
|
|||||||
410
vllm/model_executor/layers/fused_moe/config.py
Normal file
410
vllm/model_executor/layers/fused_moe/config.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.quantization import (QuantizationArgs,
|
||||||
|
QuantizationStrategy,
|
||||||
|
QuantizationType)
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import ParallelConfig
|
||||||
|
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quant_config_quantization_args(
|
||||||
|
quant_config: Optional[QuantizationConfig],
|
||||||
|
prop_name: str,
|
||||||
|
) -> Optional[QuantizationArgs]:
|
||||||
|
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||||
|
and "Linear" in quant_config.target_scheme_map and
|
||||||
|
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||||
|
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_config_input_quant(
|
||||||
|
quant_config: Optional[QuantizationConfig]
|
||||||
|
) -> Optional[QuantizationArgs]:
|
||||||
|
return _get_quant_config_quantization_args(quant_config,
|
||||||
|
"input_activations")
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_config_weight_quant(
|
||||||
|
quant_config: Optional[QuantizationConfig]
|
||||||
|
) -> Optional[QuantizationArgs]:
|
||||||
|
return _get_quant_config_quantization_args(quant_config, "weights")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (bnell): use scalar_type instead of bools?
|
||||||
|
def get_config_quant_dtype(
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
use_int4_w4a16: bool,
|
||||||
|
) -> Optional[torch.dtype]:
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
elif use_int8_w8a8:
|
||||||
|
return torch.int8
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FusedMoEQuantConfig:
|
||||||
|
# The post quantization activation type.
|
||||||
|
quant_dtype: Optional[torch.dtype] = None
|
||||||
|
per_act_token_quant: bool = False
|
||||||
|
per_out_ch_quant: bool = False
|
||||||
|
block_shape: Optional[list[int]] = None
|
||||||
|
|
||||||
|
# TODO: add col major flag?
|
||||||
|
# add detailed quant info for input, intermediates, weights, etc?
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
use_int8_w8a8: bool = False,
|
||||||
|
use_int8_w8a16: bool = False,
|
||||||
|
use_int4_w4a16: bool = False,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
per_out_ch_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> "FusedMoEQuantConfig":
|
||||||
|
assert sum([
|
||||||
|
int(flag) for flag in [
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
use_int4_w4a16,
|
||||||
|
]
|
||||||
|
]) <= 1, "Quantization flags are mutually exclusive."
|
||||||
|
|
||||||
|
quant_dtype = get_config_quant_dtype(
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
)
|
||||||
|
return FusedMoEQuantConfig(
|
||||||
|
quant_dtype,
|
||||||
|
per_act_token_quant,
|
||||||
|
per_out_ch_quant,
|
||||||
|
block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FusedMoEParallelConfig:
|
||||||
|
tp_size: int
|
||||||
|
dp_size: int
|
||||||
|
ep_size: int
|
||||||
|
tp_rank: int
|
||||||
|
dp_rank: int
|
||||||
|
ep_rank: int
|
||||||
|
world_size: int
|
||||||
|
|
||||||
|
use_ep: bool # whether to use EP or not
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_all2all_kernels(self):
|
||||||
|
return self.dp_size > 1 and self.use_ep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_pplx_kernels(self):
|
||||||
|
return (self.use_all2all_kernels
|
||||||
|
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_deepep_ht_kernels(self):
|
||||||
|
return (self.use_all2all_kernels
|
||||||
|
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_deepep_ll_kernels(self):
|
||||||
|
return (self.use_all2all_kernels
|
||||||
|
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(tp_size_: int, dp_size_: int, world_size_: int,
|
||||||
|
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||||
|
"""
|
||||||
|
Determine MoE parallel configuration. Based on the input tp_size_,
|
||||||
|
dp_size_, ep_size_ and vllm's parallel config, determine what
|
||||||
|
levels of parallelism to use in the fused moe layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tp_size_ (int): tp_size passed into the FusedMoE constructor.
|
||||||
|
dp_size_ (int): dp_size passed into the FusedMoE constructor.
|
||||||
|
ep_size_ (int): ep_size passed into the FusedMoE constructor.
|
||||||
|
world_size_ (int): the world size of the current All2All manager.
|
||||||
|
vllm_parallel_config (ParallelConfig): vllm's parallel config
|
||||||
|
object.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
|
||||||
|
we simply return the sizes unaltered and the ranks set to 0.
|
||||||
|
|
||||||
|
Expert Parallelism is considered only when either dp_size_ or tp_size_
|
||||||
|
is non trivial.
|
||||||
|
|
||||||
|
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||||
|
legend : {size, rank}
|
||||||
|
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||||
|
- Comment : Tensors are sharded across 2 devices.
|
||||||
|
|
||||||
|
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||||
|
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||||
|
- Comment: There are 2 engine instances and the tensors are sharded
|
||||||
|
across 2 devices.
|
||||||
|
|
||||||
|
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||||
|
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||||
|
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||||
|
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||||
|
- Comment: There are 2 engine instances and the tensors are sharded
|
||||||
|
across 4 devices.
|
||||||
|
|
||||||
|
When TP = 2, DP = 1 and EP = True, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||||
|
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||||
|
- Comment: The experts are split between the 2 devices.
|
||||||
|
|
||||||
|
When TP = 1, DP = 2 and EP = True, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||||
|
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||||
|
- Comment: There are 2 engine instances and the experts are split
|
||||||
|
between the 2 devices.
|
||||||
|
|
||||||
|
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||||
|
devices,
|
||||||
|
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||||
|
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||||
|
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||||
|
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||||
|
- Comment: There are 2 engine instances and the experts are split
|
||||||
|
between the 4 devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def flatten_tp_across_dp(dp_rank: int):
|
||||||
|
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||||
|
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||||
|
# and tp_rank so we shard across all devices.
|
||||||
|
tp_size = dp_size_ * tp_size_
|
||||||
|
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||||
|
return tp_size, tp_rank
|
||||||
|
|
||||||
|
use_ep = (dp_size_ * tp_size_ > 1
|
||||||
|
and vllm_parallel_config.enable_expert_parallel)
|
||||||
|
|
||||||
|
dp_size = dp_size_
|
||||||
|
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||||
|
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||||
|
|
||||||
|
if not use_ep:
|
||||||
|
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
dp_size=dp_size,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
ep_size=1,
|
||||||
|
ep_rank=0,
|
||||||
|
world_size=world_size_,
|
||||||
|
use_ep=False)
|
||||||
|
# DP + EP / TP + EP / DP + TP + EP
|
||||||
|
assert use_ep
|
||||||
|
# In EP, each device owns a set of experts fully. There is no tensor
|
||||||
|
# parallel. Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||||
|
ep_size = tp_size
|
||||||
|
ep_rank = tp_rank
|
||||||
|
return FusedMoEParallelConfig(tp_size=1,
|
||||||
|
tp_rank=0,
|
||||||
|
dp_size=dp_size,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
ep_size=ep_size,
|
||||||
|
ep_rank=ep_rank,
|
||||||
|
world_size=world_size_,
|
||||||
|
use_ep=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||||
|
@dataclass
|
||||||
|
class FusedMoEConfig:
|
||||||
|
num_experts: int
|
||||||
|
experts_per_token: int
|
||||||
|
hidden_dim: int
|
||||||
|
|
||||||
|
num_local_experts: int
|
||||||
|
moe_parallel_config: FusedMoEParallelConfig
|
||||||
|
|
||||||
|
# The activation type.
|
||||||
|
in_dtype: torch.dtype
|
||||||
|
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig] = None
|
||||||
|
|
||||||
|
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.dp_size > 1:
|
||||||
|
logger.debug("Using FusedMoEConfig::max_num_tokens=%d",
|
||||||
|
self.max_num_tokens)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||||
|
if self.quant_config is not None:
|
||||||
|
return self.quant_config.quant_dtype
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def block_shape(self) -> Optional[list[int]]:
|
||||||
|
if self.quant_config is not None:
|
||||||
|
return self.quant_config.block_shape
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_act_token_quant(self) -> bool:
|
||||||
|
if self.quant_config is not None:
|
||||||
|
return self.quant_config.per_act_token_quant
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_out_ch_quant(self) -> bool:
|
||||||
|
if self.quant_config is not None:
|
||||||
|
return self.quant_config.per_out_ch_quant
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tp_size(self):
|
||||||
|
return self.moe_parallel_config.tp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dp_size(self):
|
||||||
|
return self.moe_parallel_config.dp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ep_size(self):
|
||||||
|
return self.moe_parallel_config.ep_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def world_size(self):
|
||||||
|
return self.moe_parallel_config.world_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tp_rank(self):
|
||||||
|
return self.moe_parallel_config.tp_rank
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dp_rank(self):
|
||||||
|
return self.moe_parallel_config.dp_rank
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ep_rank(self):
|
||||||
|
return self.moe_parallel_config.ep_rank
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_ep(self):
|
||||||
|
return self.moe_parallel_config.use_ep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_pplx_kernels(self):
|
||||||
|
return self.moe_parallel_config.use_pplx_kernels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_deepep_ht_kernels(self):
|
||||||
|
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_deepep_ll_kernels(self):
|
||||||
|
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(
|
||||||
|
num_experts: int,
|
||||||
|
experts_per_token: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
num_local_experts: int,
|
||||||
|
moe_parallel_config: FusedMoEParallelConfig,
|
||||||
|
in_dtype: torch.dtype,
|
||||||
|
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||||
|
quant_config: Optional[Union[FusedMoEQuantConfig,
|
||||||
|
QuantizationConfig]] = None
|
||||||
|
) -> "FusedMoEConfig":
|
||||||
|
|
||||||
|
_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||||
|
|
||||||
|
if quant_config is not None and isinstance(quant_config,
|
||||||
|
QuantizationConfig):
|
||||||
|
if hasattr(quant_config, 'weight_block_size'):
|
||||||
|
block_shape = quant_config.weight_block_size
|
||||||
|
else:
|
||||||
|
block_shape = None
|
||||||
|
per_act_token_quant = False
|
||||||
|
per_out_ch_quant = False
|
||||||
|
quant_dtype: Optional[torch.dtype] = None
|
||||||
|
|
||||||
|
input_quant = get_quant_config_input_quant(quant_config)
|
||||||
|
weight_quant = get_quant_config_weight_quant(quant_config)
|
||||||
|
|
||||||
|
if input_quant is not None:
|
||||||
|
per_act_token_quant = (input_quant.strategy
|
||||||
|
== QuantizationStrategy.TOKEN
|
||||||
|
if input_quant is not None else False)
|
||||||
|
|
||||||
|
if input_quant.num_bits == 8:
|
||||||
|
if input_quant.type == QuantizationType.FLOAT:
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
elif input_quant.type == QuantizationType.INT:
|
||||||
|
quant_dtype = torch.int8
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||||
|
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
if weight_quant is not None:
|
||||||
|
per_out_ch_quant = (
|
||||||
|
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||||
|
|
||||||
|
if quant_dtype is not None:
|
||||||
|
_quant_config = FusedMoEQuantConfig(
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
per_out_ch_quant=per_out_ch_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_quant_config = FusedMoEQuantConfig()
|
||||||
|
logger.warning_once("MoE DP setup unable to determine "
|
||||||
|
"quantization scheme or unsupported "
|
||||||
|
"quantization type. This model will "
|
||||||
|
"not run with DP enabled.")
|
||||||
|
else:
|
||||||
|
_quant_config = quant_config
|
||||||
|
|
||||||
|
return FusedMoEConfig(
|
||||||
|
num_experts=num_experts,
|
||||||
|
experts_per_token=experts_per_token,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
num_local_experts=num_local_experts,
|
||||||
|
moe_parallel_config=moe_parallel_config,
|
||||||
|
in_dtype=in_dtype,
|
||||||
|
quant_config=_quant_config,
|
||||||
|
max_num_tokens=max_num_tokens,
|
||||||
|
)
|
||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
MoEPrepareAndFinalizeNoEP)
|
MoEPrepareAndFinalizeNoEP)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
||||||
@@ -202,26 +203,47 @@ def run_cutlass_moe_fp8(
|
|||||||
|
|
||||||
|
|
||||||
# TODO (bnell): split class batched vs. non-batched?
|
# TODO (bnell): split class batched vs. non-batched?
|
||||||
|
# maybe remove need for passing aq to workspace_shapes
|
||||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_experts_per_worker: int,
|
max_experts_per_worker: int,
|
||||||
out_dtype: torch.dtype,
|
out_dtype: Optional[torch.dtype],
|
||||||
per_act_token: bool,
|
per_act_token_quant: bool,
|
||||||
per_out_ch: bool,
|
per_out_ch_quant: bool,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
use_batched_format: bool = False,
|
use_batched_format: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
FusedMoEQuantConfig(
|
||||||
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
per_out_ch_quant=per_out_ch_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
))
|
||||||
|
assert max_experts_per_worker > 0
|
||||||
self.max_experts_per_worker = max_experts_per_worker
|
self.max_experts_per_worker = max_experts_per_worker
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
self.per_act_token = per_act_token
|
|
||||||
self.per_out_ch = per_out_ch
|
|
||||||
self.use_batched_format = use_batched_format
|
self.use_batched_format = use_batched_format
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
if self.use_batched_format:
|
||||||
|
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||||
|
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
else:
|
||||||
|
return (mk.FusedMoEActivationFormat.Standard,
|
||||||
|
mk.FusedMoEActivationFormat.Standard)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return not self.use_batched_format
|
return not self.use_batched_format
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return not self.use_batched_format
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -245,7 +267,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
workspace1 = (M * topk, max(2 * N, K))
|
workspace1 = (M * topk, max(2 * N, K))
|
||||||
workspace2 = (M * topk, N)
|
workspace2 = (M * topk, N)
|
||||||
output = (M * topk, K)
|
output = (M * topk, K)
|
||||||
return (workspace1, workspace2, output, self.out_dtype)
|
return (workspace1, workspace2, output,
|
||||||
|
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -270,13 +293,14 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||||
activation_callable = lambda i, o: self.activation(activation, i, o)
|
activation_callable = lambda i, o: self.activation(activation, i, o)
|
||||||
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
|
in_dtype = hidden_states.dtype
|
||||||
activation_callable, global_num_experts,
|
run_cutlass_moe_fp8(
|
||||||
expert_map, w1_scale, w2_scale, a1q_scale,
|
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||||
a2_scale, workspace13, workspace2,
|
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||||
expert_num_tokens, self.out_dtype,
|
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||||
self.per_act_token, self.per_out_ch,
|
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||||
self.use_batched_format)
|
self.per_act_token_quant, self.per_out_ch_quant,
|
||||||
|
self.use_batched_format)
|
||||||
|
|
||||||
|
|
||||||
def cutlass_moe_fp8(
|
def cutlass_moe_fp8(
|
||||||
@@ -287,6 +311,7 @@ def cutlass_moe_fp8(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
|
per_act_token: bool,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
@@ -330,22 +355,18 @@ def cutlass_moe_fp8(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
|
||||||
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
||||||
|
|
||||||
out_dtype = a.dtype
|
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
||||||
|
0)
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
|
||||||
per_channel_quant=per_act_token,
|
|
||||||
),
|
|
||||||
CutlassExpertsFp8(
|
CutlassExpertsFp8(
|
||||||
max_experts_per_worker=global_num_experts,
|
max_experts_per_worker=num_experts,
|
||||||
out_dtype=out_dtype,
|
out_dtype=a.dtype,
|
||||||
per_act_token=per_act_token,
|
per_act_token_quant=per_act_token,
|
||||||
per_out_ch=per_out_ch,
|
per_out_ch_quant=per_out_ch,
|
||||||
use_batched_format=False,
|
use_batched_format=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -358,7 +379,7 @@ def cutlass_moe_fp8(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
False,
|
False,
|
||||||
activation,
|
activation,
|
||||||
global_num_experts if global_num_experts != -1 else w1_q.size(0),
|
num_experts,
|
||||||
expert_map,
|
expert_map,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
_moe_permute)
|
_moe_permute)
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
MoEPrepareAndFinalizeNoEP)
|
MoEPrepareAndFinalizeNoEP)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
_resize_cache, per_token_group_quant_fp8)
|
||||||
per_token_group_quant_fp8)
|
|
||||||
from vllm.utils import has_deep_gemm, round_up
|
from vllm.utils import has_deep_gemm, round_up
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -65,16 +65,31 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
|||||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
self.block_shape = deep_gemm_block_shape()
|
FusedMoEQuantConfig(
|
||||||
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape=deep_gemm_block_shape(),
|
||||||
|
))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (mk.FusedMoEActivationFormat.Standard,
|
||||||
|
mk.FusedMoEActivationFormat.Standard)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||||
topk: int, global_num_experts: int, local_num_experts: int
|
topk: int, global_num_experts: int, local_num_experts: int
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
|
assert self.block_shape is not None
|
||||||
# We use global_num_experts due to how moe_align_block_size handles
|
# We use global_num_experts due to how moe_align_block_size handles
|
||||||
# expert_maps.
|
# expert_maps.
|
||||||
num_experts = global_num_experts
|
num_experts = global_num_experts
|
||||||
@@ -107,6 +122,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_num_tokens: Optional[torch.Tensor],
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
import deep_gemm as dg
|
import deep_gemm as dg
|
||||||
|
assert self.block_shape is not None
|
||||||
|
|
||||||
a1q = hidden_states
|
a1q = hidden_states
|
||||||
_, N, K = w1.size()
|
_, N, K = w1.size()
|
||||||
@@ -213,8 +229,7 @@ def deep_gemm_moe_fp8(
|
|||||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
block_shape=deep_gemm_block_shape()),
|
|
||||||
DeepGemmExperts(),
|
DeepGemmExperts(),
|
||||||
)
|
)
|
||||||
return fn(
|
return fn(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
|
|
||||||
@@ -15,22 +16,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int,
|
||||||
buffer: deep_ep.Buffer,
|
dp_size: int, rank_expert_offset: int):
|
||||||
world_size: int,
|
|
||||||
rank: int,
|
|
||||||
dp_size: int,
|
|
||||||
rank_expert_offset: int,
|
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
|
||||||
block_shape: Optional[list[int]] = None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.rank_expert_offset = rank_expert_offset
|
self.rank_expert_offset = rank_expert_offset
|
||||||
self.quant_dtype = quant_dtype
|
|
||||||
self.block_shape = block_shape
|
|
||||||
# The dispatch function returns a handle that the combine function
|
# The dispatch function returns a handle that the combine function
|
||||||
# requires. We store the handle here so it is available to the
|
# requires. We store the handle here so it is available to the
|
||||||
# combine function.
|
# combine function.
|
||||||
@@ -39,6 +32,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
||||||
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
|
return mk.FusedMoEActivationFormat.Standard
|
||||||
|
|
||||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -55,13 +52,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
return None
|
return None
|
||||||
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
||||||
|
|
||||||
def _do_quant(self, tokens: torch.Tensor,
|
|
||||||
token_scales: Optional[torch.Tensor], per_act_token: bool):
|
|
||||||
tokens, token_scales = moe_kernel_quantize_input(
|
|
||||||
tokens, token_scales, self.quant_dtype, per_act_token,
|
|
||||||
self.block_shape)
|
|
||||||
return tokens, token_scales
|
|
||||||
|
|
||||||
def _do_dispatch(self, tokens: torch.Tensor,
|
def _do_dispatch(self, tokens: torch.Tensor,
|
||||||
token_scales: Optional[torch.Tensor],
|
token_scales: Optional[torch.Tensor],
|
||||||
rank_topk_ids: torch.Tensor,
|
rank_topk_ids: torch.Tensor,
|
||||||
@@ -130,43 +120,51 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
a1_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
a2_scale: Optional[torch.Tensor],
|
||||||
rank_topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
rank_topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
topk = rank_topk_ids.size(1)
|
topk = topk_ids.size(1)
|
||||||
# TODO: this only works for topK=1, will need to update for topK>1
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
assert topk == 1, (
|
assert topk == 1, (
|
||||||
"apply_router_weight_on_input is only implemented for topk=1")
|
"apply_router_weight_on_input is only implemented for topk=1")
|
||||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
a1 = a1 * topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
# Check if there is a block_shape / or if we can infer the quantization
|
# Check if there is a block_shape / or if we can infer the quantization
|
||||||
# schemes from the scales.
|
# schemes from the scales.
|
||||||
per_token_quant = None
|
per_token_quant = None
|
||||||
if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
|
if all([
|
||||||
]) and self.quant_dtype is not None:
|
x is None
|
||||||
|
for x in [quant_config.block_shape, a1_scale, a2_scale]
|
||||||
|
]) and quant_config.quant_dtype is not None:
|
||||||
# Quantization required despite none of the inputs suggesting
|
# Quantization required despite none of the inputs suggesting
|
||||||
# quantization. Fallback to per_token_dynamic quant.
|
# quantization. Fallback to per_token_dynamic quant.
|
||||||
per_token_quant = True
|
per_token_quant = True
|
||||||
else:
|
else:
|
||||||
per_token_quant = ((self.block_shape is not None) or
|
per_token_quant = False
|
||||||
(a1_scale is not None and a1_scale.numel() != 1)
|
|
||||||
or (a2_scale is not None
|
|
||||||
and a2_scale.numel() != 1))
|
|
||||||
|
|
||||||
if per_token_quant:
|
if per_token_quant:
|
||||||
a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
|
a1,
|
||||||
|
a1_scale,
|
||||||
|
quant_dtype=quant_config.quant_dtype,
|
||||||
|
per_act_token_quant=True,
|
||||||
|
block_shape=quant_config.block_shape,
|
||||||
|
)
|
||||||
|
if a1q_scale is not None and a1q_scale.numel() == 1:
|
||||||
|
a1q_scale = a1q_scale.view(1, 1)
|
||||||
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||||
expert_topk_weights) = self._do_dispatch(
|
expert_topk_weights) = self._do_dispatch(
|
||||||
tokens=a1q,
|
tokens=a1q,
|
||||||
token_scales=a1q_scale,
|
token_scales=a1q_scale,
|
||||||
rank_topk_ids=rank_topk_ids,
|
rank_topk_ids=topk_ids,
|
||||||
rank_topk_weights=rank_topk_weights,
|
rank_topk_weights=topk_weights,
|
||||||
num_experts=num_experts)
|
num_experts=num_experts)
|
||||||
else:
|
else:
|
||||||
# DeepEP kernels only support dispatching per-token-quant
|
# DeepEP kernels only support dispatching per-token-quant
|
||||||
@@ -175,15 +173,18 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
expert_topk_weights) = self._do_dispatch(
|
expert_topk_weights) = self._do_dispatch(
|
||||||
tokens=a1,
|
tokens=a1,
|
||||||
token_scales=None,
|
token_scales=None,
|
||||||
rank_topk_ids=rank_topk_ids,
|
rank_topk_ids=topk_ids,
|
||||||
rank_topk_weights=rank_topk_weights,
|
rank_topk_weights=topk_weights,
|
||||||
num_experts=num_experts)
|
num_experts=num_experts)
|
||||||
# quantize now
|
# quantize now
|
||||||
expert_x_scale = None
|
expert_x_scale = None
|
||||||
if expert_x.numel() != 0:
|
if expert_x.numel() != 0:
|
||||||
expert_x, expert_x_scale = self._do_quant(expert_x,
|
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||||
a1_scale,
|
expert_x,
|
||||||
per_act_token=False)
|
a1_scale,
|
||||||
|
quant_dtype=quant_config.quant_dtype,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape=quant_config.block_shape)
|
||||||
|
|
||||||
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||||
expert_topk_weights)
|
expert_topk_weights)
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ import deep_ep
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
maybe_fix_scales, moe_kernel_quantize_input)
|
||||||
|
|
||||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||||
|
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
|
||||||
|
|
||||||
|
|
||||||
def dequant_fp8(expert_x_fp8: torch.Tensor,
|
def dequant_fp8(expert_x_fp8: torch.Tensor,
|
||||||
@@ -35,30 +37,30 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
# DeepEP low-latency kernels are compiled only for certain
|
# DeepEP low-latency kernels are compiled only for certain
|
||||||
# specific hidden sizes.
|
# specific hidden sizes.
|
||||||
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
|
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
buffer: deep_ep.Buffer,
|
buffer: deep_ep.Buffer,
|
||||||
|
max_tokens_per_rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
dp_size: int,
|
dp_size: int,
|
||||||
max_tokens_per_rank: int,
|
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
use_fp8_dispatch: bool = False):
|
use_fp8_dispatch: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
|
self.max_tokens_per_rank = max_tokens_per_rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.quant_dtype = quant_dtype
|
|
||||||
self.block_shape = block_shape
|
|
||||||
self.max_tokens_per_rank = max_tokens_per_rank
|
|
||||||
self.use_fp8_dispatch = use_fp8_dispatch
|
self.use_fp8_dispatch = use_fp8_dispatch
|
||||||
# The dispatch function returns a handle that the combine function
|
# The dispatch function returns a handle that the combine function
|
||||||
# requires. We store the handle here so it is available to the
|
# requires. We store the handle here so it is available to the
|
||||||
# combine function.
|
# combine function.
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
|
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||||
|
|
||||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||||
return self.max_tokens_per_rank
|
return self.max_tokens_per_rank
|
||||||
|
|
||||||
@@ -66,12 +68,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
return torch.int64
|
return torch.int64
|
||||||
|
|
||||||
def _do_quant(
|
def _do_quant(
|
||||||
self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
self,
|
||||||
a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
|
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||||
a1_dtype: torch.dtype
|
a1_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
a1_dtype: torch.dtype,
|
||||||
|
quant_dtype: Optional[torch.dtype],
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
|
||||||
block_k = self.block_shape[1] if self.block_shape is not None else None
|
block_k = block_shape[1] if block_shape is not None else None
|
||||||
if self.use_fp8_dispatch:
|
if self.use_fp8_dispatch:
|
||||||
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
||||||
# DeepEP kernels did the quantization for us.
|
# DeepEP kernels did the quantization for us.
|
||||||
@@ -84,32 +91,20 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
assert isinstance(x, torch.Tensor)
|
assert isinstance(x, torch.Tensor)
|
||||||
|
|
||||||
# Check if there is a block_shape / or if we can infer the quantization
|
assert not per_act_token_quant
|
||||||
# schemes from the scales.
|
|
||||||
per_token_quant = None
|
|
||||||
if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
|
|
||||||
]) and self.quant_dtype is not None:
|
|
||||||
# Quantization required despite none of the inputs suggesting
|
|
||||||
# quantization. Fallback to per_token_dynamic quant.
|
|
||||||
per_token_quant = True
|
|
||||||
else:
|
|
||||||
per_token_quant = ((self.block_shape is not None) or
|
|
||||||
(a1_scale is not None and a1_scale.numel() != 1)
|
|
||||||
or (a2_scale is not None
|
|
||||||
and a2_scale.numel() != 1))
|
|
||||||
|
|
||||||
num_experts, max_tokens, hidden_dim = x.size()
|
num_experts, max_tokens, hidden_dim = x.size()
|
||||||
|
|
||||||
# TODO (varun): Optimization - Use a batched version of quant
|
# TODO (varun): Optimization - Use a batched version of quant
|
||||||
x = x.view((-1, hidden_dim))
|
x = x.view((-1, hidden_dim))
|
||||||
x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
|
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
|
||||||
per_token_quant,
|
per_act_token_quant,
|
||||||
self.block_shape)
|
block_shape)
|
||||||
x = x.view((num_experts, -1, hidden_dim))
|
x = x.view((num_experts, -1, hidden_dim))
|
||||||
|
|
||||||
if per_token_quant:
|
if quant_dtype is not None:
|
||||||
assert x_scales is not None
|
assert x_scales is not None
|
||||||
x_scales = x_scales.view(num_experts, max_tokens, -1)
|
x_scales = maybe_fix_scales(x_scales, num_experts)
|
||||||
|
|
||||||
return x, x_scales
|
return x, x_scales
|
||||||
|
|
||||||
@@ -118,11 +113,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
a1_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
a2_scale: Optional[torch.Tensor],
|
||||||
rank_topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
rank_topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
|
||||||
@@ -142,24 +138,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
"low_latency kernels doesn't support dispatching per-token scales")
|
"low_latency kernels doesn't support dispatching per-token scales")
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
topk = rank_topk_ids.size(1)
|
topk = topk_ids.size(1)
|
||||||
# TODO: this only works for topK=1, will need to update for topK>1
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
assert topk == 1, (
|
assert topk == 1, (
|
||||||
"apply_router_weight_on_input is only implemented for topk=1")
|
"apply_router_weight_on_input is only implemented for topk=1")
|
||||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
a1 = a1 * topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
expert_x, expert_num_tokens, self.handle, event, hook = \
|
expert_x, expert_num_tokens, self.handle, event, hook = \
|
||||||
self.buffer.low_latency_dispatch(a1,
|
self.buffer.low_latency_dispatch(a1,
|
||||||
rank_topk_ids,
|
topk_ids,
|
||||||
self.max_tokens_per_rank,
|
self.max_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
use_fp8=self.use_fp8_dispatch,
|
use_fp8=self.use_fp8_dispatch,
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
return_recv_hook=False)
|
return_recv_hook=False)
|
||||||
|
|
||||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
expert_x, expert_x_scale = self._do_quant(
|
||||||
a1.dtype)
|
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||||
|
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||||
|
|
||||||
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
|
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
get_config_dtype_str, try_get_optimal_moe_config)
|
get_config_dtype_str, try_get_optimal_moe_config)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
@@ -317,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
|
|||||||
expert_num_tokens: torch.Tensor, # [E]
|
expert_num_tokens: torch.Tensor, # [E]
|
||||||
compute_type: tl.dtype,
|
compute_type: tl.dtype,
|
||||||
# Quantization data
|
# Quantization data
|
||||||
A_scale: torch.Tensor,
|
A_scale: Optional[torch.Tensor],
|
||||||
B_scale: torch.Tensor,
|
B_scale: Optional[torch.Tensor],
|
||||||
B_zp: torch.Tensor,
|
B_zp: torch.Tensor,
|
||||||
# Quantization schemes
|
# Quantization schemes
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
@@ -387,14 +388,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
that the PPLX dispatch/combine kernels use.
|
that the PPLX dispatch/combine kernels use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
|
def __init__(
|
||||||
rank: int):
|
self,
|
||||||
|
max_num_tokens: int,
|
||||||
|
world_size: int,
|
||||||
|
dp_size: int,
|
||||||
|
rank: int,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
|
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||||
|
|
||||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||||
return self.max_num_tokens
|
return self.max_num_tokens
|
||||||
|
|
||||||
@@ -411,6 +421,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
assert a1.dim() == 2
|
assert a1.dim() == 2
|
||||||
@@ -435,22 +446,35 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
num_local_experts = num_experts // self.world_size
|
num_local_experts = num_experts // self.world_size
|
||||||
|
|
||||||
|
if quant_config.quant_dtype is None:
|
||||||
|
b_type = a1.dtype
|
||||||
|
else:
|
||||||
|
b_type = quant_config.quant_dtype
|
||||||
|
|
||||||
b_a1 = torch.zeros(
|
b_a1 = torch.zeros(
|
||||||
(num_local_experts, self.max_num_tokens, hidden_dim),
|
(num_local_experts, self.max_num_tokens, hidden_dim),
|
||||||
dtype=a1.dtype,
|
dtype=b_type,
|
||||||
device=a1.device)
|
device=a1.device)
|
||||||
|
|
||||||
|
b_a1_scale = None
|
||||||
|
|
||||||
|
assert quant_config.quant_dtype is None, "quantization NYI"
|
||||||
|
|
||||||
first_expert = num_local_experts * self.rank
|
first_expert = num_local_experts * self.rank
|
||||||
last_expert = first_expert + num_local_experts
|
last_expert = first_expert + num_local_experts
|
||||||
|
|
||||||
for expert_id in range(first_expert, last_expert):
|
for expert_id in range(first_expert, last_expert):
|
||||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||||
rows = torch.count_nonzero(topks.flatten())
|
rows = torch.count_nonzero(topks.flatten())
|
||||||
b_a1[expert_id -
|
if rows == 0:
|
||||||
first_expert, :rows, :] = a1[:topks.numel()][topks]
|
continue
|
||||||
tokens_per_expert[expert_id - first_expert] = rows
|
idx = expert_id - first_expert
|
||||||
|
b_a1[idx, :rows, :] = a1[:topks.numel()][topks]
|
||||||
|
tokens_per_expert[idx] = rows
|
||||||
|
|
||||||
return b_a1, a1_scale, tokens_per_expert, None, None
|
assert b_a1_scale is None or b_a1_scale.ndim == 3
|
||||||
|
|
||||||
|
return b_a1, b_a1_scale, tokens_per_expert, None, None
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self,
|
self,
|
||||||
@@ -480,7 +504,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
output[topks] = output[topks] + rhs
|
output[topks] = output[topks] + rhs
|
||||||
|
|
||||||
|
|
||||||
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
"""
|
"""
|
||||||
A reference MoE expert class that operates on expert batched format,
|
A reference MoE expert class that operates on expert batched format,
|
||||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||||
@@ -497,11 +521,17 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
block_m: Optional[int] = None,
|
per_act_token_quant: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
assert block_shape is None
|
FusedMoEQuantConfig.make(
|
||||||
assert block_m is None
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
))
|
||||||
assert not use_fp8_w8a8, "NYI"
|
assert not use_fp8_w8a8, "NYI"
|
||||||
assert not use_int8_w8a8, "NYI"
|
assert not use_int8_w8a8, "NYI"
|
||||||
assert not use_int8_w8a16, "NYI"
|
assert not use_int8_w8a16, "NYI"
|
||||||
@@ -510,9 +540,19 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||||
|
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -554,20 +594,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
assert hidden_states.dim() == 3
|
assert hidden_states.dim() == 3
|
||||||
assert expert_num_tokens is not None
|
assert expert_num_tokens is not None
|
||||||
|
|
||||||
max_num_tokens = self.max_num_tokens
|
|
||||||
num_dp = self.world_size // self.dp_size
|
|
||||||
num_local_experts = w1.size(0)
|
num_local_experts = w1.size(0)
|
||||||
assert num_local_experts == w1.size(0), (
|
assert num_local_experts == w1.size(0), (
|
||||||
f"{num_local_experts} == {w1.size(0)}")
|
f"{num_local_experts} == {w1.size(0)}")
|
||||||
|
|
||||||
N = w1.size(1) // 2
|
N = w1.size(1) // 2
|
||||||
|
|
||||||
# Not cudagraph friendly
|
|
||||||
assert (torch.compiler.is_compiling()
|
|
||||||
or torch.cuda.is_current_stream_capturing()
|
|
||||||
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
|
|
||||||
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
|
|
||||||
|
|
||||||
for expert in range(num_local_experts):
|
for expert in range(num_local_experts):
|
||||||
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
|
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
|
||||||
if (torch.compiler.is_compiling()
|
if (torch.compiler.is_compiling()
|
||||||
@@ -575,6 +607,10 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
num = hidden_states.shape[1]
|
num = hidden_states.shape[1]
|
||||||
else:
|
else:
|
||||||
num = int(expert_num_tokens[expert].item())
|
num = int(expert_num_tokens[expert].item())
|
||||||
|
|
||||||
|
if num == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
tmp = _resize_cache(workspace2, (num, N))
|
tmp = _resize_cache(workspace2, (num, N))
|
||||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||||
self.activation(activation, tmp, input)
|
self.activation(activation, tmp, input)
|
||||||
@@ -590,34 +626,53 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_num_tokens: Optional[int] = None,
|
max_num_tokens: int,
|
||||||
|
world_size: int,
|
||||||
|
dp_size: int,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
per_channel_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
world_size: int = 1,
|
|
||||||
dp_size: int = 1,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
FusedMoEQuantConfig.make(
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
))
|
||||||
|
assert not use_int8_w8a8, "NYI"
|
||||||
|
assert not use_int8_w8a16, "NYI"
|
||||||
|
assert not use_int4_w4a16, "NYI"
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||||
self.use_int8_w8a8 = use_int8_w8a8
|
self.use_int8_w8a8 = use_int8_w8a8
|
||||||
self.use_int4_w4a16 = use_int4_w4a16
|
self.use_int4_w4a16 = use_int4_w4a16
|
||||||
self.use_int8_w8a16 = use_int8_w8a16
|
self.use_int8_w8a16 = use_int8_w8a16
|
||||||
self.block_shape = block_shape
|
|
||||||
self.per_channel_quant = per_channel_quant
|
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
|
assert world_size > 0
|
||||||
|
assert dp_size > 0
|
||||||
|
assert dp_size <= world_size
|
||||||
|
assert max_num_tokens > 0
|
||||||
|
|
||||||
assert not use_int8_w8a8, "NYI"
|
@property
|
||||||
assert not use_int4_w4a16, "NYI"
|
def activation_formats(
|
||||||
assert self.block_shape is None, "NYI"
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||||
|
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -630,10 +685,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||||
assert a.dim() == 2
|
assert a.dim() == 2
|
||||||
num_dp = self.world_size // self.dp_size
|
num_dp = self.world_size
|
||||||
num_experts = local_num_experts
|
num_experts = local_num_experts
|
||||||
max_num_tokens = a.size(
|
max_num_tokens = self.max_num_tokens
|
||||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
|
||||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||||
output = (num_experts, max_num_tokens * num_dp, K)
|
output = (num_experts, max_num_tokens * num_dp, K)
|
||||||
@@ -708,7 +762,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||||
|
|
||||||
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
|
|
||||||
# We can reuse the memory between these because by the time we need
|
# We can reuse the memory between these because by the time we need
|
||||||
# cache3, we're done with cache1
|
# cache3, we're done with cache1
|
||||||
intermediate_cache1 = _resize_cache(workspace13,
|
intermediate_cache1 = _resize_cache(workspace13,
|
||||||
@@ -734,6 +787,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
config=config,
|
config=config,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
|
|
||||||
|
intermediate_cache2.fill_(0)
|
||||||
|
|
||||||
# TODO: would be nice to use expert_num_tokens here to reduce
|
# TODO: would be nice to use expert_num_tokens here to reduce
|
||||||
# garbage compute
|
# garbage compute
|
||||||
self.activation(activation, intermediate_cache2.view(-1, N // 2),
|
self.activation(activation, intermediate_cache2.view(-1, N // 2),
|
||||||
@@ -745,8 +800,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
A=intermediate_cache2,
|
A=intermediate_cache2,
|
||||||
A_scale=a2_scale,
|
A_scale=a2_scale,
|
||||||
qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
|
quant_dtype=self.quant_dtype,
|
||||||
per_channel_quant=self.per_channel_quant,
|
per_act_token_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
|
|
||||||
qintermediate_cache2 = qintermediate_cache2.view(
|
qintermediate_cache2 = qintermediate_cache2.view(
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import vllm.envs as envs
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, get_config_quant_dtype)
|
||||||
|
# yapf: enable
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
_valid_deep_gemm, deep_gemm_moe_fp8)
|
_valid_deep_gemm, deep_gemm_moe_fp8)
|
||||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||||
@@ -980,20 +984,6 @@ def get_config_dtype_str(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# TODO (bnell): use scalar_type instead of bools?
|
|
||||||
def get_config_qtype(
|
|
||||||
use_fp8_w8a8: bool,
|
|
||||||
use_int8_w8a8: bool,
|
|
||||||
use_int8_w8a16: bool,
|
|
||||||
use_int4_w4a16: bool,
|
|
||||||
) -> Optional[torch.dtype]:
|
|
||||||
if use_fp8_w8a8:
|
|
||||||
return torch.float8_e4m3fn
|
|
||||||
elif use_int8_w8a8:
|
|
||||||
return torch.int8
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def inplace_fused_experts(hidden_states: torch.Tensor,
|
def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
@@ -1262,10 +1252,10 @@ def fused_experts_impl(
|
|||||||
use_int4_w4a16=use_int4_w4a16,
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
dtype=hidden_states.dtype)
|
dtype=hidden_states.dtype)
|
||||||
|
|
||||||
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
use_int4_w4a16=use_int4_w4a16)
|
use_int4_w4a16=use_int4_w4a16)
|
||||||
|
|
||||||
get_config_func = functools.partial(
|
get_config_func = functools.partial(
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
@@ -1332,8 +1322,8 @@ def fused_experts_impl(
|
|||||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||||
A=curr_hidden_states,
|
A=curr_hidden_states,
|
||||||
A_scale=a1_scale,
|
A_scale=a1_scale,
|
||||||
qtype=qtype,
|
quant_dtype=qtype,
|
||||||
per_channel_quant=per_channel_quant,
|
per_act_token_quant=per_channel_quant,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||||
@@ -1373,8 +1363,8 @@ def fused_experts_impl(
|
|||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
A=intermediate_cache2,
|
A=intermediate_cache2,
|
||||||
A_scale=a2_scale,
|
A_scale=a2_scale,
|
||||||
qtype=qtype,
|
quant_dtype=qtype,
|
||||||
per_channel_quant=per_channel_quant,
|
per_act_token_quant=per_channel_quant,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||||
@@ -1521,30 +1511,41 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool = False,
|
||||||
use_int4_w4a16: bool,
|
use_int4_w4a16: bool = False,
|
||||||
per_channel_quant: bool,
|
per_act_token_quant: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
block_m: Optional[int] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
FusedMoEQuantConfig.make(
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
))
|
||||||
|
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||||
self.use_int4_w4a16 = use_int4_w4a16
|
self.use_int4_w4a16 = use_int4_w4a16
|
||||||
self.use_int8_w8a8 = use_int8_w8a8
|
self.use_int8_w8a8 = use_int8_w8a8
|
||||||
self.use_int8_w8a16 = use_int8_w8a16
|
self.use_int8_w8a16 = use_int8_w8a16
|
||||||
self.block_shape = block_shape
|
|
||||||
self.block_m = block_m
|
@property
|
||||||
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
def activation_formats(
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
self
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
use_int4_w4a16=use_int4_w4a16)
|
return (mk.FusedMoEActivationFormat.Standard,
|
||||||
self.per_channel_quant = per_channel_quant
|
mk.FusedMoEActivationFormat.Standard)
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -1660,7 +1661,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
use_int8_w8a8=self.use_int8_w8a8,
|
use_int8_w8a8=self.use_int8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.use_int4_w4a16,
|
||||||
per_channel_quant=self.per_channel_quant,
|
per_channel_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
|
|
||||||
self.activation(activation, intermediate_cache2,
|
self.activation(activation, intermediate_cache2,
|
||||||
@@ -1669,8 +1670,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
a2q_scale: Optional[torch.Tensor] = None
|
a2q_scale: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
|
intermediate_cache2, a2_scale, self.quant_dtype,
|
||||||
self.block_shape)
|
self.per_act_token_quant, self.block_shape)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
@@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
use_int8_w8a8=self.use_int8_w8a8,
|
use_int8_w8a8=self.use_int8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.use_int4_w4a16,
|
||||||
per_channel_quant=self.per_channel_quant,
|
per_channel_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
|
|
||||||
|
|
||||||
@@ -1699,27 +1700,17 @@ def modular_triton_fused_moe(
|
|||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int4_w4a16: bool,
|
use_int4_w4a16: bool,
|
||||||
per_channel_quant: bool,
|
per_act_token_quant: bool,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> mk.FusedMoEModularKernel:
|
) -> mk.FusedMoEModularKernel:
|
||||||
qtype = get_config_qtype(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
)
|
|
||||||
return mk.FusedMoEModularKernel(
|
return mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
quant_dtype=qtype,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
),
|
|
||||||
TritonExperts(
|
TritonExperts(
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
per_channel_quant=per_channel_quant,
|
per_act_token_quant=per_act_token_quant,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,27 +3,30 @@
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Literal, Optional, Union, overload
|
from typing import Callable, Literal, Optional, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from compressed_tensors.quantization import (QuantizationArgs,
|
|
||||||
QuantizationStrategy,
|
|
||||||
QuantizationType)
|
|
||||||
from torch.nn.parameter import UninitializedParameter
|
from torch.nn.parameter import UninitializedParameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
|
get_world_group,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.distributed.eplb.eplb_state import EplbState
|
from vllm.distributed.eplb.eplb_state import EplbState
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig, FusedMoEParallelConfig)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
|
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||||
|
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
is_rocm_aiter_moe_enabled)
|
is_rocm_aiter_moe_enabled)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@@ -36,14 +39,12 @@ from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
|
|||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .fused_batched_moe import BatchedTritonExperts
|
from .fused_batched_moe import BatchedTritonExperts
|
||||||
from .fused_moe import TritonExperts, fused_experts
|
from .fused_moe import TritonExperts, fused_experts
|
||||||
from .modular_kernel import (FusedMoEModularKernel,
|
|
||||||
FusedMoEPermuteExpertsUnpermute,
|
|
||||||
FusedMoEPrepareAndFinalize)
|
|
||||||
if has_pplx():
|
if has_pplx():
|
||||||
from .pplx_prepare_finalize import PplxPrepareAndFinalize
|
from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
|
||||||
|
pplx_hidden_dim_scale_bytes)
|
||||||
if has_deep_ep():
|
if has_deep_ep():
|
||||||
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
|
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
|
||||||
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
|
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
|
||||||
DeepEPLLPrepareAndFinalize)
|
DeepEPLLPrepareAndFinalize)
|
||||||
else:
|
else:
|
||||||
fused_experts = None # type: ignore
|
fused_experts = None # type: ignore
|
||||||
@@ -60,209 +61,10 @@ if current_platform.is_tpu():
|
|||||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||||
else:
|
else:
|
||||||
fused_moe_pallas = None # type: ignore
|
fused_moe_pallas = None # type: ignore
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FusedMoEParallelConfig:
|
|
||||||
tp_size: int
|
|
||||||
dp_size: int
|
|
||||||
ep_size: int
|
|
||||||
tp_rank: int
|
|
||||||
dp_rank: int
|
|
||||||
ep_rank: int
|
|
||||||
|
|
||||||
use_ep: bool # whether to use EP or not
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_all2all_kernels(self):
|
|
||||||
return self.dp_size > 1 and self.use_ep
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_pplx_kernels(self):
|
|
||||||
return (self.use_all2all_kernels
|
|
||||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_deepep_ht_kernels(self):
|
|
||||||
return (self.use_all2all_kernels
|
|
||||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_deepep_ll_kernels(self):
|
|
||||||
return (self.use_all2all_kernels
|
|
||||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def make(tp_size_: int, dp_size_: int,
|
|
||||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
|
||||||
"""
|
|
||||||
Determine MoE parallel configuration. Based on the input tp_size_,
|
|
||||||
dp_size_, ep_size_ and vllm's parallel config, determine what
|
|
||||||
level's of parallelism to use in the fused moe layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tp_size_ (int): tp_size passed into the FusedMoE constructor.
|
|
||||||
dp_size_ (int): dp_size passed into the FusedMoE constructor.
|
|
||||||
ep_size_ (int): ep_size passed into the FusedMoE constructor.
|
|
||||||
vllm_parallel_config (ParallelConfig): vllm's parallel config
|
|
||||||
object.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
|
|
||||||
we simply return the sizes unaltered and the ranks set to 0.
|
|
||||||
|
|
||||||
Expert Parallelism is considered only when either dp_size_ or tp_size_
|
|
||||||
is non trivial.
|
|
||||||
|
|
||||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
|
||||||
legend : {size, rank}
|
|
||||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
|
||||||
- Comment : Tensors are sharded across 2 devices.
|
|
||||||
|
|
||||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
|
||||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
|
||||||
- Comment: There are 2 engine instances and the tensors are sharded
|
|
||||||
across 2 decvices.
|
|
||||||
|
|
||||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
|
||||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
|
||||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
|
||||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
|
||||||
- Comment: There are 2 engine instances and the tensors are sharded
|
|
||||||
across 4 devices.
|
|
||||||
|
|
||||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
|
||||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
|
||||||
- Comment: The experts are split between the 2 devices.
|
|
||||||
|
|
||||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
|
||||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
|
||||||
- Comment: There are 2 engine instances and the experts are split
|
|
||||||
between the 2 devices.
|
|
||||||
|
|
||||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
|
||||||
devices,
|
|
||||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
|
||||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
|
||||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
|
||||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
|
||||||
- Comment: There are 2 engine instances and the experts are split
|
|
||||||
between the 4 devices.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def flatten_tp_across_dp(dp_rank: int):
|
|
||||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
|
||||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
|
||||||
# and tp_rank so we shard across all devices.
|
|
||||||
tp_size = dp_size_ * tp_size_
|
|
||||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
|
||||||
return tp_size, tp_rank
|
|
||||||
|
|
||||||
use_ep = (dp_size_ * tp_size_ > 1
|
|
||||||
and vllm_parallel_config.enable_expert_parallel)
|
|
||||||
|
|
||||||
dp_size = dp_size_
|
|
||||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
|
||||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
|
||||||
|
|
||||||
if not use_ep:
|
|
||||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
|
||||||
tp_rank=tp_rank,
|
|
||||||
dp_size=dp_size,
|
|
||||||
dp_rank=dp_rank,
|
|
||||||
ep_size=1,
|
|
||||||
ep_rank=0,
|
|
||||||
use_ep=False)
|
|
||||||
# DP + EP / TP + EP / DP + TP + EP
|
|
||||||
assert use_ep
|
|
||||||
# In EP, each device owns a set of experts fully. There is no tensor
|
|
||||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
|
||||||
ep_size = tp_size
|
|
||||||
ep_rank = tp_rank
|
|
||||||
return FusedMoEParallelConfig(tp_size=1,
|
|
||||||
tp_rank=0,
|
|
||||||
dp_size=dp_size,
|
|
||||||
dp_rank=dp_rank,
|
|
||||||
ep_size=ep_size,
|
|
||||||
ep_rank=ep_rank,
|
|
||||||
use_ep=True)
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
|
||||||
@dataclass
|
|
||||||
class MoEConfig:
|
|
||||||
num_experts: int
|
|
||||||
experts_per_token: int
|
|
||||||
hidden_dim: int
|
|
||||||
|
|
||||||
num_local_experts: int
|
|
||||||
moe_parallel_config: FusedMoEParallelConfig
|
|
||||||
|
|
||||||
in_dtype: torch.dtype # The activation type.
|
|
||||||
quant_dtype: torch.dtype = None
|
|
||||||
|
|
||||||
# TODO: add more quantization params, blocked, per-token, etc.
|
|
||||||
block_size: int = 128
|
|
||||||
|
|
||||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.dp_size > 1:
|
|
||||||
logger.debug("Using MOEConfig::max_num_tokens=%d",
|
|
||||||
self.max_num_tokens)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tp_size(self):
|
|
||||||
return self.moe_parallel_config.tp_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dp_size(self):
|
|
||||||
return self.moe_parallel_config.dp_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ep_size(self):
|
|
||||||
return self.moe_parallel_config.ep_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tp_rank(self):
|
|
||||||
return self.moe_parallel_config.tp_rank
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dp_rank(self):
|
|
||||||
return self.moe_parallel_config.dp_rank
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ep_rank(self):
|
|
||||||
return self.moe_parallel_config.ep_rank
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_ep(self):
|
|
||||||
return self.moe_parallel_config.use_ep
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_pplx_kernels(self):
|
|
||||||
return self.moe_parallel_config.use_pplx_kernels
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_deepep_ht_kernels(self):
|
|
||||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_deepep_ll_kernels(self):
|
|
||||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoeWeightScaleSupported(Enum):
|
class FusedMoeWeightScaleSupported(Enum):
|
||||||
TENSOR = "tensor"
|
TENSOR = "tensor"
|
||||||
CHANNEL = "channel"
|
CHANNEL = "channel"
|
||||||
@@ -270,21 +72,9 @@ class FusedMoeWeightScaleSupported(Enum):
|
|||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
def get_quant_config_input_activations(
|
|
||||||
quant_config: Optional[QuantizationConfig]
|
|
||||||
) -> Optional[QuantizationArgs]:
|
|
||||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
|
||||||
and "Linear" in quant_config.target_scheme_map and
|
|
||||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
|
||||||
return quant_config.target_scheme_map["Linear"].get(
|
|
||||||
"input_activations")
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
moe: MoEConfig
|
moe: FusedMoEConfig
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
@@ -292,23 +82,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_prepare_finalize(self, moe: MoEConfig,
|
def init_prepare_finalize(self, moe: FusedMoEConfig,
|
||||||
quant_config: Optional[QuantizationConfig]):
|
quant_config: Optional[QuantizationConfig]):
|
||||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||||
assert all2all_manager is not None
|
assert all2all_manager is not None
|
||||||
|
|
||||||
self.moe = moe
|
self.moe = moe
|
||||||
quant_dtype = None
|
|
||||||
act_quant_block_size = None
|
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
if isinstance(quant_config, Fp8Config):
|
|
||||||
act_quant_block_size = quant_config.weight_block_size
|
|
||||||
quant_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
|
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||||
DeepEPHTPrepareAndFinalize,
|
|
||||||
DeepEPLLPrepareAndFinalize]] = None
|
|
||||||
if moe.use_pplx_kernels:
|
if moe.use_pplx_kernels:
|
||||||
|
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||||
|
moe.max_num_tokens,
|
||||||
|
moe.hidden_dim,
|
||||||
|
moe.in_dtype,
|
||||||
|
moe.quant_dtype,
|
||||||
|
per_act_token_quant=moe.per_act_token_quant,
|
||||||
|
block_shape=moe.block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
max_num_tokens=moe.max_num_tokens,
|
max_num_tokens=moe.max_num_tokens,
|
||||||
num_experts=moe.num_experts,
|
num_experts=moe.num_experts,
|
||||||
@@ -318,14 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
hidden_dim=moe.hidden_dim,
|
hidden_dim=moe.hidden_dim,
|
||||||
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
|
hidden_dim_bytes=hidden_dim_bytes,
|
||||||
# For blocked per token: set to
|
hidden_dim_scale_bytes=hidden_scale_bytes,
|
||||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
|
||||||
# For per-token: set to sizeof(float32)
|
|
||||||
hidden_dim_scale_bytes=(
|
|
||||||
0 if moe.quant_dtype.itemsize != 1 else
|
|
||||||
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
|
|
||||||
torch.float32.itemsize)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Intranode pplx a2a takes a group name while internode does not.
|
# Intranode pplx a2a takes a group name while internode does not.
|
||||||
@@ -335,9 +121,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
|
|
||||||
handle = all2all_manager.get_handle(all_to_all_args)
|
handle = all2all_manager.get_handle(all_to_all_args)
|
||||||
|
|
||||||
input_activations = get_quant_config_input_activations(
|
|
||||||
quant_config)
|
|
||||||
|
|
||||||
prepare_finalize = PplxPrepareAndFinalize(
|
prepare_finalize = PplxPrepareAndFinalize(
|
||||||
handle,
|
handle,
|
||||||
max_num_tokens=moe.max_num_tokens,
|
max_num_tokens=moe.max_num_tokens,
|
||||||
@@ -345,10 +128,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
rank=all2all_manager.rank,
|
rank=all2all_manager.rank,
|
||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
quant_dtype=moe.quant_dtype,
|
|
||||||
per_act_token=(input_activations.strategy
|
|
||||||
== QuantizationStrategy.TOKEN
|
|
||||||
if input_activations is not None else False),
|
|
||||||
)
|
)
|
||||||
elif moe.use_deepep_ht_kernels:
|
elif moe.use_deepep_ht_kernels:
|
||||||
assert moe.dp_size == all2all_manager.dp_world_size
|
assert moe.dp_size == all2all_manager.dp_world_size
|
||||||
@@ -362,8 +141,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
dp_size=all2all_manager.dp_world_size,
|
dp_size=all2all_manager.dp_world_size,
|
||||||
rank_expert_offset=all2all_manager.rank *
|
rank_expert_offset=all2all_manager.rank *
|
||||||
moe.num_local_experts,
|
moe.num_local_experts,
|
||||||
quant_dtype=quant_dtype,
|
|
||||||
block_shape=act_quant_block_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif moe.use_deepep_ll_kernels:
|
elif moe.use_deepep_ll_kernels:
|
||||||
@@ -380,25 +157,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
|
|
||||||
# Note : We may want to use FP8 dispatch even otherwise just to
|
# Note : We may want to use FP8 dispatch even otherwise just to
|
||||||
# reduce datamovement
|
# reduce datamovement
|
||||||
assert act_quant_block_size is not None
|
use_fp8_dispatch = (moe.quant_config is not None
|
||||||
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
|
and moe.quant_config.quant_dtype
|
||||||
and act_quant_block_size[1]
|
== current_platform.fp8_dtype()
|
||||||
== DEEPEP_QUANT_BLOCK_SIZE)
|
and moe.quant_config.block_shape
|
||||||
|
== DEEPEP_QUANT_BLOCK_SHAPE)
|
||||||
|
|
||||||
# Note (varun): Whether to use FP8 dispatch or not needs some
|
# Note (varun): Whether to use FP8 dispatch or not needs some
|
||||||
# profiling. Turning it off for now.
|
# profiling. Turning it off for now.
|
||||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||||
handle,
|
handle,
|
||||||
|
max_tokens_per_rank=moe.max_num_tokens,
|
||||||
world_size=all2all_manager.world_size,
|
world_size=all2all_manager.world_size,
|
||||||
dp_size=all2all_manager.dp_world_size,
|
dp_size=all2all_manager.dp_world_size,
|
||||||
max_tokens_per_rank=moe.max_num_tokens,
|
|
||||||
quant_dtype=quant_dtype,
|
|
||||||
block_shape=act_quant_block_size,
|
|
||||||
use_fp8_dispatch=use_fp8_dispatch,
|
use_fp8_dispatch=use_fp8_dispatch,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
if prepare_finalize is not None:
|
if prepare_finalize is not None:
|
||||||
|
logger.debug("%s", prepare_finalize.__class__.__name__)
|
||||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||||
experts = self.select_gemm_impl(prepare_finalize, moe)
|
experts = self.select_gemm_impl(prepare_finalize, moe)
|
||||||
self.fused_experts = FusedMoEModularKernel(
|
self.fused_experts = FusedMoEModularKernel(
|
||||||
@@ -407,13 +184,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
self,
|
||||||
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
moe: FusedMoEConfig,
|
||||||
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
# based on the all2all implementation, select the appropriate
|
# based on the all2all implementation, select the appropriate
|
||||||
# gemm implementation
|
# gemm implementation
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Subclass must select appropriate gemm implementation"
|
f"{self.__class__.__name__} must select appropriate gemm "
|
||||||
" based on the prepare_finalize")
|
"implementation based on the prepare_finalize")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(
|
def apply(
|
||||||
@@ -445,7 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
"""MoE method without quantization."""
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
def __init__(self, moe: MoEConfig):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fused_experts = fused_experts # type: ignore
|
self.fused_experts = fused_experts # type: ignore
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
@@ -458,44 +237,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
else:
|
else:
|
||||||
self.rocm_aiter_fused_experts = None # type: ignore
|
self.rocm_aiter_fused_experts = None # type: ignore
|
||||||
|
|
||||||
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
def select_gemm_impl(
|
||||||
moe: Optional[MoEConfig]):
|
self,
|
||||||
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
moe: FusedMoEConfig,
|
||||||
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
|
|
||||||
assert self.fused_experts == fused_experts
|
assert self.fused_experts == fused_experts
|
||||||
|
|
||||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||||
assert all2all_manager is not None
|
assert all2all_manager is not None
|
||||||
|
|
||||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
if (prepare_finalize.activation_format ==
|
||||||
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
|
|
||||||
) is not None
|
|
||||||
if use_batched_experts:
|
|
||||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||||
assert self.moe.dp_size == all2all_manager.dp_world_size
|
assert self.moe.dp_size == all2all_manager.dp_world_size
|
||||||
experts = BatchedTritonExperts(
|
return BatchedTritonExperts(
|
||||||
max_num_tokens=self.moe.max_num_tokens,
|
max_num_tokens=self.moe.max_num_tokens,
|
||||||
world_size=all2all_manager.world_size,
|
world_size=all2all_manager.world_size,
|
||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
use_fp8_w8a8=False,
|
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
block_shape=None,
|
|
||||||
per_channel_quant=False,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("TritonExperts %s", self.moe)
|
logger.debug("TritonExperts %s", self.moe)
|
||||||
experts = TritonExperts(
|
return TritonExperts()
|
||||||
use_fp8_w8a8=False,
|
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
block_shape=None,
|
|
||||||
per_channel_quant=False,
|
|
||||||
)
|
|
||||||
return experts
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@@ -883,13 +648,18 @@ class FusedMoE(torch.nn.Module):
|
|||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
|
tp_size_ = (tp_size if tp_size is not None else
|
||||||
|
get_tensor_model_parallel_world_size())
|
||||||
|
dp_size_ = (dp_size
|
||||||
|
if dp_size is not None else get_dp_group().world_size)
|
||||||
|
world_size_ = get_world_group().world_size
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||||
FusedMoEParallelConfig.make(
|
FusedMoEParallelConfig.make(
|
||||||
tp_size_=(tp_size if tp_size is not None else
|
tp_size_=tp_size_,
|
||||||
get_tensor_model_parallel_world_size()),
|
dp_size_=dp_size_,
|
||||||
dp_size_=(dp_size if dp_size is not None else
|
world_size_=world_size_,
|
||||||
get_dp_group().world_size),
|
|
||||||
vllm_parallel_config=vllm_config.parallel_config))
|
vllm_parallel_config=vllm_config.parallel_config))
|
||||||
|
|
||||||
self.global_num_experts = num_experts + num_redundant_experts
|
self.global_num_experts = num_experts + num_redundant_experts
|
||||||
@@ -948,25 +718,22 @@ class FusedMoE(torch.nn.Module):
|
|||||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||||
|
|
||||||
# Only support float8 for now.
|
if vllm_config.model_config is not None:
|
||||||
quant_dtype = params_dtype
|
model_dtype = vllm_config.model_config.dtype
|
||||||
if quant_config is not None:
|
else:
|
||||||
input_activations = get_quant_config_input_activations(
|
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||||
quant_config)
|
# since model_config is not set in the pytest test.
|
||||||
if (input_activations is not None
|
model_dtype = params_dtype
|
||||||
and input_activations.num_bits == 8
|
|
||||||
and input_activations.type == QuantizationType.FLOAT):
|
|
||||||
quant_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
moe = MoEConfig(
|
moe = FusedMoEConfig.make(
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
hidden_dim=hidden_size,
|
hidden_dim=hidden_size,
|
||||||
num_local_experts=self.local_num_experts,
|
num_local_experts=self.local_num_experts,
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
in_dtype=params_dtype,
|
in_dtype=model_dtype,
|
||||||
quant_dtype=quant_dtype,
|
|
||||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.moe_config = moe
|
self.moe_config = moe
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
@@ -1017,16 +784,15 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||||
if (self.moe_parallel_config.use_pplx_kernels
|
if (self.moe_parallel_config.use_pplx_kernels
|
||||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||||
act_dtype = vllm_config.model_config.dtype
|
|
||||||
self.batched_hidden_states = torch.zeros(
|
self.batched_hidden_states = torch.zeros(
|
||||||
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
|
(moe.max_num_tokens, self.hidden_size),
|
||||||
dtype=act_dtype,
|
dtype=moe.in_dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
# Note here we use `num_experts` which is logical expert count
|
# Note here we use `num_experts` which is logical expert count
|
||||||
self.batched_router_logits = torch.zeros(
|
self.batched_router_logits = torch.zeros(
|
||||||
(envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
|
(moe.max_num_tokens, num_experts),
|
||||||
dtype=act_dtype,
|
dtype=moe.in_dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
from math import prod
|
from math import prod
|
||||||
from typing import Optional
|
from typing import Optional, final
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
@@ -82,6 +84,18 @@ def _moe_problem_size(
|
|||||||
return E, M, N, K, topk
|
return E, M, N, K, topk
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEActivationFormat(Enum):
|
||||||
|
"""
|
||||||
|
The standard activation format (num_tokens, hidden dim).
|
||||||
|
"""
|
||||||
|
Standard = "standard",
|
||||||
|
"""
|
||||||
|
The batched experts format (num experts, max tokens per expert, hidden dim)
|
||||||
|
"""
|
||||||
|
BatchedExperts = "batched_experts",
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||||
class FusedMoEPrepareAndFinalize(ABC):
|
class FusedMoEPrepareAndFinalize(ABC):
|
||||||
"""
|
"""
|
||||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||||
@@ -99,6 +113,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
@@ -148,6 +163,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def activation_format(self) -> FusedMoEActivationFormat:
|
||||||
|
"""
|
||||||
|
A property indicating the output format of the activations for the
|
||||||
|
'prepare' method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||||
"""
|
"""
|
||||||
@@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
above.
|
above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig],
|
||||||
|
):
|
||||||
|
if quant_config is not None:
|
||||||
|
self.quant_config = quant_config
|
||||||
|
else:
|
||||||
|
self.quant_config = FusedMoEQuantConfig()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def activation_formats(
|
||||||
|
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
|
||||||
|
"""
|
||||||
|
A property which is a tuple of the input and output activation formats
|
||||||
|
for the 'apply' method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||||
|
return self.quant_config.quant_dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def block_shape(self) -> Optional[list[int]]:
|
||||||
|
return self.quant_config.block_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_act_token_quant(self) -> bool:
|
||||||
|
return self.quant_config.per_act_token_quant
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_out_ch_quant(self) -> bool:
|
||||||
|
return self.quant_config.per_out_ch_quant
|
||||||
|
|
||||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
@@ -185,6 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
"""
|
||||||
|
A flag indicating whether or not this class supports expert maps
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
@@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
class FusedMoEModularKernel(torch.nn.Module):
|
class FusedMoEModularKernel(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||||
@@ -318,6 +385,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.prepare_finalize = prepare_finalize
|
self.prepare_finalize = prepare_finalize
|
||||||
self.fused_experts = fused_experts
|
self.fused_experts = fused_experts
|
||||||
|
assert prepare_finalize.activation_format == \
|
||||||
|
fused_experts.activation_formats[0], (
|
||||||
|
f"{prepare_finalize.__class__.__name__}."
|
||||||
|
f"{prepare_finalize.activation_format} == "
|
||||||
|
f"{fused_experts.__class__.__name__}."
|
||||||
|
f"{fused_experts.activation_formats[0]}")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -383,8 +456,16 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
|
|
||||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||||
a1, a1_scale, a2_scale, topk_weights, topk_ids,
|
a1,
|
||||||
global_num_experts, expert_map, apply_router_weight_on_input)
|
a1_scale,
|
||||||
|
a2_scale,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
apply_router_weight_on_input,
|
||||||
|
self.fused_experts.quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||||
|
|||||||
@@ -6,33 +6,76 @@ import pplx_kernels as pplx
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
|
from vllm.utils import cdiv, round_up
|
||||||
|
|
||||||
|
|
||||||
|
def pplx_hidden_dim_scale_bytes(
|
||||||
|
max_num_tokens: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
in_dtype: torch.dtype,
|
||||||
|
quant_dtype: Optional[torch.dtype],
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
|
):
|
||||||
|
# All pplx byte sizes must be 16-byte aligned.
|
||||||
|
align = 16
|
||||||
|
|
||||||
|
# For blocked per token: set to
|
||||||
|
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||||
|
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
|
||||||
|
if quant_dtype is not None:
|
||||||
|
assert quant_dtype.itemsize == 1
|
||||||
|
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
|
||||||
|
elem_size = torch.float32.itemsize
|
||||||
|
|
||||||
|
if per_act_token_quant:
|
||||||
|
# per-token
|
||||||
|
assert block_shape is None
|
||||||
|
hidden_scale_bytes = elem_size
|
||||||
|
elif block_shape is not None:
|
||||||
|
# per-group
|
||||||
|
block_size = block_shape[1]
|
||||||
|
num_blocks = cdiv(hidden_dim, block_size)
|
||||||
|
hidden_scale_bytes = num_blocks * elem_size
|
||||||
|
else:
|
||||||
|
# per-tensor
|
||||||
|
hidden_scale_bytes = elem_size
|
||||||
|
else:
|
||||||
|
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
|
||||||
|
hidden_scale_bytes = 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
round_up(hidden_dim_bytes, align),
|
||||||
|
round_up(hidden_scale_bytes, align),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# The max_num_tokens, world_size and dp_size must be the same
|
# The max_num_tokens, world_size and dp_size must be the same
|
||||||
# as the ones used to create the AllToAll.
|
# as the ones used to create the AllToAll.
|
||||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
a2a: pplx.AllToAll,
|
self,
|
||||||
max_num_tokens: int,
|
a2a: pplx.AllToAll,
|
||||||
world_size: int,
|
max_num_tokens: int,
|
||||||
rank: int,
|
world_size: int,
|
||||||
dp_size: int,
|
rank: int,
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
dp_size: int,
|
||||||
block_shape: Optional[list[int]] = None,
|
):
|
||||||
per_act_token: bool = False):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert max_num_tokens > 0
|
assert max_num_tokens > 0
|
||||||
self.a2a = a2a
|
self.a2a = a2a
|
||||||
self.block_shape = block_shape
|
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.dp_size = dp_size
|
self.dp_size = dp_size
|
||||||
self.quant_dtype = quant_dtype
|
|
||||||
self.per_act_token = per_act_token
|
@property
|
||||||
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
|
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||||
|
|
||||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||||
return self.max_num_tokens
|
return self.max_num_tokens
|
||||||
@@ -45,36 +88,43 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
a1_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
a2_scale: Optional[torch.Tensor],
|
||||||
rank_topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
rank_topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
num_tokens = a1.size(0) # M
|
num_tokens = a1.size(0) # M
|
||||||
hidden_dim = a1.size(-1) # K
|
hidden_dim = a1.size(-1) # K
|
||||||
|
|
||||||
assert rank_topk_ids.size(0) == num_tokens
|
assert topk_ids.size(0) == num_tokens
|
||||||
# assert expert_map is None, "NYI"
|
# assert expert_map is None, "NYI"
|
||||||
|
|
||||||
# Is this always going to be a1.device?
|
# Is this always going to be a1.device?
|
||||||
device = a1.device
|
device = a1.device
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
topk = rank_topk_ids.size(1)
|
topk = topk_ids.size(1)
|
||||||
# TODO: this only works for topK=1, will need to update for topK>1
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
assert topk == 1, (
|
assert topk == 1, (
|
||||||
"apply_router_weight_on_input is only implemented for topk=1")
|
"apply_router_weight_on_input is only implemented for topk=1")
|
||||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
a1 = a1 * topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
repeat_cols = 4
|
repeat_cols = 4
|
||||||
repeat_rows = 1 if self.per_act_token else a1.size(0)
|
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
|
a1, (None if quant_config.per_act_token_quant else a1_scale),
|
||||||
self.per_act_token, self.block_shape)
|
quant_dtype=quant_config.quant_dtype,
|
||||||
|
per_act_token_quant=quant_config.per_act_token_quant,
|
||||||
|
block_shape=quant_config.block_shape)
|
||||||
|
|
||||||
if a1q_scale is not None:
|
if a1q_scale is not None:
|
||||||
|
if a1q_scale.numel() == 1:
|
||||||
|
orig_a_scale_block_shape = 1
|
||||||
|
else:
|
||||||
|
orig_a_scale_block_shape = a1q_scale.shape[-1]
|
||||||
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
|
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
|
||||||
|
|
||||||
# rem_experts need to be 0 for pplx to work properly.
|
# rem_experts need to be 0 for pplx to work properly.
|
||||||
@@ -98,15 +148,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
expert_x_scale: Optional[torch.Tensor] = None
|
expert_x_scale: Optional[torch.Tensor] = None
|
||||||
if a1q.dtype.itemsize == 1:
|
if a1q.dtype.itemsize == 1:
|
||||||
float32_size = torch.float32.itemsize
|
block_size = (quant_config.block_shape[1]
|
||||||
block_size = (self.block_shape[0] if self.block_shape is not None
|
if quant_config.block_shape is not None else 1)
|
||||||
else 1) * float32_size
|
|
||||||
expert_x_scale = torch.empty(
|
expert_x_scale = torch.empty(
|
||||||
(
|
(num_local_experts, expert_x.size(1),
|
||||||
num_local_experts,
|
round_up(
|
||||||
expert_x.size(1),
|
(expert_x.size(2) + block_size - 1) // block_size, 4)),
|
||||||
(expert_x.size(2) + block_size - 1) // block_size,
|
|
||||||
),
|
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
@@ -121,11 +168,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
out_expert_x_scale=expert_x_scale,
|
out_expert_x_scale=expert_x_scale,
|
||||||
dp_x=a1q,
|
dp_x=a1q,
|
||||||
dp_x_scale=a1q_scale,
|
dp_x_scale=a1q_scale,
|
||||||
indices=rank_topk_ids,
|
indices=topk_ids,
|
||||||
bound_m=bound_m,
|
bound_m=bound_m,
|
||||||
)
|
)
|
||||||
if expert_x_scale is not None:
|
if expert_x_scale is not None:
|
||||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||||
|
|
||||||
return expert_x, expert_x_scale, expert_num_tokens, None, None
|
return expert_x, expert_x_scale, expert_num_tokens, None, None
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
_moe_unpermute_and_reduce)
|
_moe_unpermute_and_reduce)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
@@ -13,16 +14,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
|||||||
|
|
||||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|
||||||
def __init__(
|
@property
|
||||||
self,
|
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
return mk.FusedMoEActivationFormat.Standard
|
||||||
per_channel_quant: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.per_channel_quant = per_channel_quant
|
|
||||||
self.block_shape = block_shape
|
|
||||||
self.quant_dtype = quant_dtype
|
|
||||||
|
|
||||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
@@ -39,7 +33,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
|
||||||
@@ -50,10 +45,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
"apply_router_weight_on_input is only implemented for topk=1"
|
"apply_router_weight_on_input is only implemented for topk=1"
|
||||||
a1.mul_(topk_weights.to(a1.dtype))
|
a1.mul_(topk_weights.to(a1.dtype))
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
self.quant_dtype,
|
a1, a1_scale, quant_config.quant_dtype,
|
||||||
self.per_channel_quant,
|
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||||
self.block_shape)
|
|
||||||
|
|
||||||
return a1q, a1q_scale, None, None, None
|
return a1q, a1q_scale, None, None, None
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||||
@@ -12,34 +13,59 @@ from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
|||||||
|
|
||||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
use_fp8_w8a8: bool = False,
|
self,
|
||||||
use_int8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
per_channel_quant: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
per_act_token_quant: bool = False,
|
||||||
block_m: Optional[int] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
allow_deep_gemm: bool = False):
|
allow_deep_gemm: bool = False,
|
||||||
super().__init__()
|
):
|
||||||
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
|
super().__init__(
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
FusedMoEQuantConfig.make(
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
per_channel_quant=per_channel_quant,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
block_shape=block_shape,
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
block_m=block_m)
|
per_act_token_quant=per_act_token_quant,
|
||||||
self.allow_deep_gemm = allow_deep_gemm
|
block_shape=block_shape,
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
))
|
||||||
|
self.triton_expert = TritonExperts(
|
||||||
|
use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
|
||||||
|
and use_fp8_w8a8)
|
||||||
self.deep_gemm_expert = DeepGemmExperts(
|
self.deep_gemm_expert = DeepGemmExperts(
|
||||||
) if self.allow_deep_gemm else None
|
) if self.allow_deep_gemm else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activation_formats(
|
||||||
|
self
|
||||||
|
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||||
|
assert (self.deep_gemm_expert is None
|
||||||
|
or self.triton_expert.activation_formats
|
||||||
|
== self.deep_gemm_expert.activation_formats)
|
||||||
|
return self.triton_expert.activation_formats
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
dge = self.deep_gemm_expert
|
dge = self.deep_gemm_expert
|
||||||
te = self.triton_expert
|
te = self.triton_expert
|
||||||
return ((dge is None or dge.supports_chunking())
|
return ((dge is None or dge.supports_chunking())
|
||||||
and (te is None or te.supports_chunking()))
|
and (te is None or te.supports_chunking()))
|
||||||
|
|
||||||
|
def supports_expert_map(self) -> bool:
|
||||||
|
dge = self.deep_gemm_expert
|
||||||
|
te = self.triton_expert
|
||||||
|
return ((dge is None or dge.supports_expert_map())
|
||||||
|
and (te is None or te.supports_expert_map()))
|
||||||
|
|
||||||
def workspace_shapes(
|
def workspace_shapes(
|
||||||
self,
|
self,
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@@ -83,9 +109,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_num_tokens: Optional[torch.Tensor],
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
N = w1.size(1)
|
use_deep_gemm = (self.allow_deep_gemm
|
||||||
|
|
||||||
use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
|
|
||||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||||
|
|
||||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ def _fp8_quantize(
|
|||||||
A, A_scale = ops.scaled_fp8_quant(
|
A, A_scale = ops.scaled_fp8_quant(
|
||||||
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||||
else:
|
else:
|
||||||
|
assert not per_act_token
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
_, block_k = block_shape[0], block_shape[1]
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||||
@@ -64,6 +65,7 @@ def _int8_quantize(
|
|||||||
"int8 quantization only supports block or channel-wise"
|
"int8 quantization only supports block or channel-wise"
|
||||||
A, A_scale = per_token_quant_int8(A)
|
A, A_scale = per_token_quant_int8(A)
|
||||||
else:
|
else:
|
||||||
|
assert not per_act_token
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
_, block_k = block_shape[0], block_shape[1]
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||||
@@ -75,16 +77,15 @@ def _int8_quantize(
|
|||||||
def moe_kernel_quantize_input(
|
def moe_kernel_quantize_input(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
A_scale: Optional[torch.Tensor],
|
A_scale: Optional[torch.Tensor],
|
||||||
qtype: Optional[torch.dtype],
|
quant_dtype: Optional[torch.dtype],
|
||||||
per_channel_quant: bool,
|
per_act_token_quant: bool,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
if qtype == torch.float8_e4m3fn:
|
if quant_dtype == torch.float8_e4m3fn:
|
||||||
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
|
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||||
elif qtype == torch.int8:
|
elif quant_dtype == torch.int8:
|
||||||
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
|
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||||
else:
|
else:
|
||||||
assert A_scale is None
|
|
||||||
return A, A_scale
|
return A, A_scale
|
||||||
|
|
||||||
|
|
||||||
@@ -96,3 +97,17 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
|||||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||||
else:
|
else:
|
||||||
return m[idx, ...]
|
return m[idx, ...]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(bnell): better name
|
||||||
|
def maybe_fix_scales(scales: Optional[torch.Tensor],
|
||||||
|
num_experts: int) -> Optional[torch.Tensor]:
|
||||||
|
if scales is not None and scales.ndim < 3:
|
||||||
|
if scales.numel() == 1:
|
||||||
|
scales = scales.view(1)
|
||||||
|
scales = torch.repeat_interleave(scales, num_experts,
|
||||||
|
dim=0).view(num_experts, 1, 1)
|
||||||
|
else:
|
||||||
|
scales = scales.view(num_experts, -1, scales.size(-1))
|
||||||
|
|
||||||
|
return scales
|
||||||
|
|||||||
@@ -13,8 +13,10 @@ from compressed_tensors.quantization import (ActivationOrdering,
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
FusedMoeWeightScaleSupported)
|
CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
|
||||||
|
FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
|
||||||
|
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||||
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
||||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
@@ -32,14 +34,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils import has_pplx
|
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
||||||
BatchedPrepareAndFinalize)
|
|
||||||
if has_pplx():
|
|
||||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
|
||||||
PplxPrepareAndFinalize)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -569,15 +563,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
||||||
else:
|
elif self.use_marlin:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
||||||
self.fused_experts_func = fused_experts
|
|
||||||
|
|
||||||
if self.use_marlin:
|
|
||||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||||
# Activations not quantized for marlin.
|
# Activations not quantized for marlin.
|
||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
self.fused_experts_func = None
|
||||||
|
else:
|
||||||
|
self.fused_experts_func = fused_experts
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -653,6 +646,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
assert self.fused_experts_func is not None
|
||||||
|
|
||||||
return self.fused_experts_func(
|
return self.fused_experts_func(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -826,28 +821,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
def select_gemm_impl(self, prepare_finalize, moe):
|
def select_gemm_impl(
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
self,
|
||||||
CutlassExpertsFp8)
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
moe: FusedMoEConfig,
|
||||||
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
|
|
||||||
assert moe is not None
|
use_batched_format = (prepare_finalize.activation_format ==
|
||||||
|
FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
|
||||||
|
num_experts = (moe.num_local_experts
|
||||||
|
if use_batched_format else moe.num_experts)
|
||||||
|
|
||||||
max_experts_per_worker = (
|
|
||||||
(moe.num_experts + prepare_finalize.world_size - 1) //
|
|
||||||
prepare_finalize.world_size)
|
|
||||||
experts = CutlassExpertsFp8(
|
experts = CutlassExpertsFp8(
|
||||||
max_experts_per_worker,
|
num_experts,
|
||||||
moe.in_dtype,
|
moe.in_dtype,
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||||
use_batched_format=True,
|
use_batched_format=use_batched_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_pplx() and isinstance(
|
self.disable_expert_map = not experts.supports_expert_map()
|
||||||
prepare_finalize,
|
|
||||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
|
||||||
# no expert_map support in this case
|
|
||||||
self.disable_expert_map = True
|
|
||||||
return experts
|
return experts
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@@ -888,7 +882,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
indices_type=torch.uint32)
|
indices_type=self.topk_indices_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
return self.fused_experts(
|
return self.fused_experts(
|
||||||
x,
|
x,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -13,8 +13,11 @@ import vllm.envs as envs
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
FusedMoeWeightScaleSupported)
|
BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
|
||||||
|
FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
|
||||||
|
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
|
||||||
|
TritonOrDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@@ -777,44 +780,46 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
|
||||||
def select_gemm_impl(self, prepare_finalize, moe):
|
def select_gemm_impl(
|
||||||
|
self,
|
||||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
BatchedTritonOrDeepGemmExperts)
|
moe: FusedMoEConfig,
|
||||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
TritonOrDeepGemmExperts)
|
|
||||||
|
|
||||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||||
|
|
||||||
experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
|
if (prepare_finalize.activation_format ==
|
||||||
TritonOrDeepGemmExperts]] = None
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
max_num_tokens_per_rank = (
|
||||||
use_batched_experts = max_num_tokens_per_rank is not None
|
prepare_finalize.max_num_tokens_per_rank())
|
||||||
|
assert max_num_tokens_per_rank is not None
|
||||||
if use_batched_experts:
|
logger.debug(
|
||||||
experts = BatchedTritonOrDeepGemmExperts(
|
"BatchedTritonOrDeepGemmExperts(%s): "
|
||||||
|
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
|
||||||
|
self.__class__.__name__, max_num_tokens_per_rank,
|
||||||
|
self.quant_config.weight_block_size, False)
|
||||||
|
return BatchedTritonOrDeepGemmExperts(
|
||||||
max_num_tokens=max_num_tokens_per_rank,
|
max_num_tokens=max_num_tokens_per_rank,
|
||||||
world_size=prepare_finalize.world_size,
|
world_size=prepare_finalize.
|
||||||
dp_size=prepare_finalize.dp_size,
|
world_size, # type: ignore [attr-defined]
|
||||||
|
dp_size=prepare_finalize.
|
||||||
|
dp_size, # type: ignore [attr-defined]
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
per_channel_quant=False,
|
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
per_act_token_quant=False,
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
experts = TritonOrDeepGemmExperts(
|
logger.debug(
|
||||||
|
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
|
||||||
|
self.__class__.__name__, self.quant_config.weight_block_size,
|
||||||
|
False)
|
||||||
|
return TritonOrDeepGemmExperts(
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert experts is not None
|
|
||||||
return experts
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
Reference in New Issue
Block a user