[Kernel] moe wna16 marlin kernel (#14447)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -1245,6 +1245,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
token_expert_indicies, gating_output)
|
||||
|
||||
|
||||
def moe_wna16_marlin_gemm(
    input: torch.Tensor,
    output: Optional[torch.Tensor],
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
) -> torch.Tensor:
    """Call the ``_moe_C.moe_wna16_marlin_gemm`` custom op.

    This is a thin Python wrapper: every argument is forwarded positionally
    in declaration order. The only transformation is that ``b_q_type`` is
    replaced by ``b_q_type.id`` before being handed to the op.

    Returns:
        Whatever tensor the underlying custom op returns.
    """
    marlin_gemm_op = torch.ops._moe_C.moe_wna16_marlin_gemm
    return marlin_gemm_op(
        input,
        output,
        b_qweight,
        b_scales,
        b_qzeros,
        g_idx,
        perm,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_past_padded,
        topk_weights,
        moe_block_size,
        top_k,
        mul_topk_weights,
        is_ep,
        # The op takes the scalar type's id, not the ScalarType object.
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
|
||||
|
||||
|
||||
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
|
||||
@register_fake("_moe_C::marlin_gemm_moe")
|
||||
@@ -1263,6 +1286,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
||||
dtype=a.dtype,
|
||||
device=a.device)
|
||||
|
||||
@register_fake("_moe_C::moe_wna16_marlin_gemm")
def moe_wna16_marlin_gemm_fake(
    input: torch.Tensor,
    output: Optional[torch.Tensor],
    b_qweight: torch.Tensor,
    b_scales: torch.Tensor,
    b_qzeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_past_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
) -> torch.Tensor:
    """Fake implementation registered for ``_moe_C::moe_wna16_marlin_gemm``.

    No GEMM is performed. Only ``input``, ``size_m``, ``top_k`` and
    ``size_n`` are read; all other arguments are accepted and ignored.

    NOTE(review): the Python wrapper ``moe_wna16_marlin_gemm`` passes
    ``b_q_type.id`` to the op, so ``b_q_type`` likely arrives here as that id
    rather than as a ``ScalarType`` -- confirm against the op schema.

    Returns:
        An uninitialized tensor of shape ``(size_m * top_k, size_n)`` with
        the dtype and device of ``input``.
    """
    out_shape = (size_m * top_k, size_n)
    return torch.empty(out_shape, dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
|
||||
@@ -5,17 +5,16 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def get_scalar_type(num_bits: int, has_zp: bool):
    """Select the weight scalar type for a given bit width and zero-point mode.

    Args:
        num_bits: Weight bit width. ``4`` selects the 4-bit type; any other
            value selects the 8-bit type (the MoE entry points in this file
            assert ``num_bits in [4, 8]`` before calling).
        has_zp: Whether the weights come with zero points.

    Returns:
        ``scalar_types.uint4`` / ``scalar_types.uint8`` when ``has_zp`` is
        true, otherwise ``scalar_types.uint4b8`` / ``scalar_types.uint8b128``.
    """
    if has_zp:
        # Previously this branch asserted 4-bit and returned uint4 first,
        # which left the 8-bit zero-point case unreachable.
        return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
    return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
|
||||
|
||||
@@ -27,9 +26,12 @@ def single_marlin_moe(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
g_idx: Optional[torch.Tensor] = None,
|
||||
sort_indices: Optional[torch.Tensor] = None,
|
||||
w_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -62,7 +64,7 @@ def single_marlin_moe(
|
||||
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w.is_contiguous(), "Expert weights must be contiguous"
|
||||
assert hidden_states.dtype == torch.float16
|
||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
||||
assert num_bits in [4, 8]
|
||||
|
||||
M, K = hidden_states.shape
|
||||
@@ -83,39 +85,54 @@ def single_marlin_moe(
|
||||
|
||||
block_size_m = config['BLOCK_SIZE_M']
|
||||
|
||||
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
||||
moe_align_block_size(topk_ids, block_size_m, E, expert_map)
|
||||
|
||||
max_workspace_size = (N // 64) * 16
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
if workspace is None:
|
||||
max_workspace_size = (max(2 * N, K) // 64) * \
|
||||
(sorted_token_ids.size(0) // block_size_m)
|
||||
device = hidden_states.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
max_workspace_size = min(max_workspace_size, sms)
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
has_zero_point = w_zeros is not None
|
||||
if w_zeros is None:
|
||||
w_zeros = torch.empty((0, 0),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
scalar_type = get_scalar_type(num_bits, w_zeros is not None)
|
||||
intermediate_cache = torch.empty(
|
||||
(M * topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
if g_idx is None:
|
||||
g_idx = torch.empty((0, 0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
|
||||
if sort_indices is None:
|
||||
sort_indices = torch.empty((0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
|
||||
scalar_type = get_scalar_type(num_bits, has_zero_point)
|
||||
|
||||
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
|
||||
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
|
||||
w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
|
||||
is_k_full, E, topk, block_size_m, True, False)
|
||||
ops.moe_wna16_marlin_gemm(hidden_states,
|
||||
intermediate_cache,
|
||||
w,
|
||||
scales,
|
||||
w_zeros,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=False,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=scalar_type,
|
||||
size_m=M,
|
||||
size_n=N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False)
|
||||
intermediate_cache = intermediate_cache.view(-1, topk, N)
|
||||
|
||||
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
|
||||
|
||||
@@ -127,9 +144,12 @@ def single_marlin_moe_fake(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
g_idx: Optional[torch.Tensor] = None,
|
||||
sort_indices: Optional[torch.Tensor] = None,
|
||||
w_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -144,24 +164,26 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
def fused_marlin_moe(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
inplace: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
@@ -196,27 +218,12 @@ def fused_marlin_moe(
|
||||
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
|
||||
assert hidden_states.shape[1] == w2.shape[2] // (
|
||||
num_bits // 2), "Hidden size mismatch w2"
|
||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert hidden_states.dtype == torch.float16
|
||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
||||
assert num_bits in [4, 8]
|
||||
|
||||
has_no_act_order = (g_idx1 is None and g_idx2 is None
|
||||
and sort_indices1 is None and sort_indices2 is None)
|
||||
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
|
||||
and sort_indices1 is not None
|
||||
and sort_indices2 is not None)
|
||||
assert has_no_act_order or has_all_act_order, (
|
||||
"g_idx and sorted_indices "
|
||||
"must be all not None or must be all None")
|
||||
|
||||
has_no_zp = w1_zeros is None and w2_zeros is None
|
||||
has_all_zp = w1_zeros is not None and w2_zeros is not None
|
||||
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
|
||||
"must be both None")
|
||||
|
||||
M, K = hidden_states.shape
|
||||
E = w1.shape[0]
|
||||
N = w2.shape[1] * 16
|
||||
@@ -234,122 +241,128 @@ def fused_marlin_moe(
|
||||
|
||||
block_size_m = config["BLOCK_SIZE_M"]
|
||||
|
||||
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = \
|
||||
moe_align_block_size(topk_ids, block_size_m, global_num_experts,
|
||||
expert_map)
|
||||
|
||||
max_workspace_size = (max(2 * N, K) // 64) * 16
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=current_platform.device_type,
|
||||
requires_grad=False)
|
||||
if workspace is None:
|
||||
max_workspace_size = (max(2 * N, K) // 64) * \
|
||||
(sorted_token_ids.size(0) // block_size_m)
|
||||
device = hidden_states.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
max_workspace_size = min(max_workspace_size, sms * 4)
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
if has_no_zp:
|
||||
w1_zeros = torch.empty((0, 0),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
w2_zeros = torch.empty((0, 0),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
|
||||
if has_no_act_order:
|
||||
g_idx1 = torch.empty((0, 0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
g_idx2 = torch.empty((0, 0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
sort_indices1 = torch.empty((0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
sort_indices2 = torch.empty((0, 0),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
requires_grad=False)
|
||||
|
||||
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
|
||||
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
|
||||
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
|
||||
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
|
||||
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache13 = torch.empty(
|
||||
(M * topk_ids.shape[1] * max(2 * N, K), ),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
|
||||
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
|
||||
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
|
||||
intermediate_cache3 = intermediate_cache3.view(-1, K)
|
||||
|
||||
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
|
||||
use_atomic_add = hidden_states.dtype == torch.half or \
|
||||
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||
|
||||
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
|
||||
hidden_states,
|
||||
intermediate_cache1,
|
||||
w1,
|
||||
sorted_token_ids,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w1_zeros,
|
||||
g_idx1,
|
||||
sort_indices1,
|
||||
workspace,
|
||||
scalar_type1.id,
|
||||
M,
|
||||
2 * N,
|
||||
K,
|
||||
is_k_full,
|
||||
E,
|
||||
topk,
|
||||
block_size_m,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=False,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=scalar_type1,
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, 2 * N))
|
||||
|
||||
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
intermediate_cache3 = ops.moe_wna16_marlin_gemm(
|
||||
intermediate_cache2,
|
||||
intermediate_cache3,
|
||||
w2,
|
||||
sorted_token_ids,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w2_scale,
|
||||
w2_zeros,
|
||||
g_idx2,
|
||||
sort_indices2,
|
||||
workspace,
|
||||
scalar_type2.id,
|
||||
M,
|
||||
K,
|
||||
N,
|
||||
is_k_full,
|
||||
E,
|
||||
topk,
|
||||
block_size_m,
|
||||
False,
|
||||
True,
|
||||
)
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=True,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type=scalar_type2,
|
||||
size_m=M * topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False).view(-1, topk, K)
|
||||
|
||||
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
dim=1)
|
||||
dim=1,
|
||||
out=output)
|
||||
|
||||
|
||||
def fused_marlin_moe_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
def fused_marlin_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """Shape-only stand-in that takes the same parameters as ``fused_marlin_moe``.

    Only ``hidden_states`` is read; every other argument is accepted and
    ignored, including ``inplace``.

    Returns:
        A new uninitialized tensor created with ``torch.empty_like`` from
        ``hidden_states``.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
|
||||
|
||||
|
||||
|
||||
@@ -773,6 +773,18 @@ def get_default_config(
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
|
||||
else:
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
|
||||
elif is_marlin:
|
||||
for block_size_m in [8, 16, 32, 48, 64]:
|
||||
if M * topk / E / block_size_m < 0.9:
|
||||
break
|
||||
return {"BLOCK_SIZE_M": block_size_m}
|
||||
elif M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
@@ -780,14 +792,6 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
}
|
||||
# A heuristic: fused marlin works faster with this config for small M
|
||||
if M <= E or (is_marlin and M <= 32):
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
|
||||
@@ -472,6 +472,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
||||
self.reduce_results = reduce_results
|
||||
self.renormalize = renormalize
|
||||
|
||||
@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
check_marlin_supports_layer, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
check_marlin_supports_layer, check_moe_marlin_supports_layer,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if layer.local_num_experts > 32:
|
||||
# For MoEs with many experts the moe_wna16 kernel is faster
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import (
|
||||
MoeWNA16Config)
|
||||
if not check_moe_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_one(
|
||||
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
|
||||
"Falling back to Moe WNA16 kernels.")
|
||||
return MoeWNA16Config.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
return AWQMoEMethod(self)
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
num_experts = layer.w13_qweight.shape[0]
|
||||
device = layer.w13_qweight.device
|
||||
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if expert_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for "
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_zeros=layer.w13_qzeros,
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
)
|
||||
|
||||
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, marlin_moe_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
check_marlin_supported, check_moe_marlin_supports_layer,
|
||||
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
|
||||
verify_marlin_supported)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantization method to use for ``layer``.

    Fused MoE layers get ``GPTQMarlinMoEMethod`` when
    ``check_moe_marlin_supports_layer`` accepts them for this config's
    ``group_size``; otherwise the method is taken from a ``MoeWNA16Config``
    built from ``self.full_config``. All other layers are delegated to
    ``get_linear_quant_method``.

    Args:
        layer: The module being quantized.
        prefix: Name of the layer; used in the fallback warning and passed
            through to the delegated lookups.
    """
    if isinstance(layer, FusedMoE):
        # Imported locally instead of at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from vllm.model_executor.layers.quantization.moe_wna16 import (
            MoeWNA16Config)
        if not check_moe_marlin_supports_layer(layer, self.group_size):
            # Was `logger.warning_one`, a typo for `warning_once` that would
            # raise AttributeError exactly when the fallback is needed.
            logger.warning_once(
                f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                "Falling back to Moe WNA16 kernels.")
            return MoeWNA16Config.from_config(
                self.full_config).get_quant_method(layer, prefix)
        return GPTQMarlinMoEMethod(self)
    return get_linear_quant_method(self, layer, prefix,
                                   GPTQMarlinLinearMethod)
|
||||
|
||||
@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
torch.empty(num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.half),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
torch.empty(num_experts,
|
||||
scales_size2,
|
||||
hidden_size,
|
||||
dtype=torch.half),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
layer.workspace = torch.zeros((sms * 4, ),
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# Process act_order
|
||||
@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_k_full=self.is_k_full).to(orig_dtype)
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full)
|
||||
|
||||
@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
||||
group_size=group_size)[0]
|
||||
|
||||
|
||||
def check_moe_marlin_supports_layer(layer: torch.nn.Module,
                                    group_size: int) -> bool:
    """Report whether the MoE Marlin kernel can be used for ``layer``.

    Args:
        layer: Fused MoE layer exposing ``hidden_size`` and
            ``intermediate_size_per_partition``. Callers pass a ``FusedMoE``
            module rather than a linear layer, hence the generic
            ``torch.nn.Module`` annotation.
        group_size: Quantization group size; must be one of
            ``-1``, ``32``, ``64`` or ``128``.

    Returns:
        ``True`` when ``hidden_size`` is a multiple of 128,
        ``intermediate_size_per_partition`` is a multiple of
        ``max(64, group_size)`` and ``group_size`` is an accepted value;
        ``False`` otherwise.
    """
    hidden_size = layer.hidden_size
    intermediate_size_per_partition = layer.intermediate_size_per_partition

    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
    # moe marlin requires n % 128 == 0 and k % 64 == 0
    return (hidden_size % 128 == 0
            and intermediate_size_per_partition % max(64, group_size) == 0
            and group_size in [-1, 32, 64, 128])
|
||||
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
max_workspace_size = (output_size_per_partition //
|
||||
|
||||
Reference in New Issue
Block a user