[Kernel] add triton fused moe kernel for gptq/awq (#12185)

This commit is contained in:
Jinzhen Lin
2025-01-29 22:07:09 +08:00
committed by GitHub
parent b02fd288b2
commit 27b78c73ca
4 changed files with 874 additions and 55 deletions


@@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel_gptq_awq(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
b_scale_ptr,
b_zp_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N: tl.constexpr,
K: tl.constexpr,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
block_k_diviable: tl.constexpr,
group_size: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
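# When no zero-point tensor is provided (symmetric GPTQ-style quantization),
# the implicit zero point is the midpoint of the unsigned range:
# 2**4 // 2 = 8 for int4 and 2**8 // 2 = 128 for int8, so dequantization
# later becomes (q - 8) * scale or (q - 128) * scale respectively.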
if not has_zp and use_int4_w4a16:
b_zp_num = 8
if not has_zp and use_int8_w8a16:
b_zp_num = 128
elif has_zp and use_int4_w4a16:
b_zp_shifter = (offs_bn[None, :] % 2) * 4
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
if not block_k_diviable:
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
k_other = 0.0
else:
k_mask = None
k_other = None
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = (b >> b_shifter) & 0xF
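# Two int4 values are packed per byte along K: `offs_k // 2` selects the
# byte and `b_shifter = (offs_k % 2) * 4` picks the low (even k) or high
# (odd k) nibble, e.g. a packed byte 0xA3 yields 3 for the even k and
# 0xA (10) for the odd k. The K pointer advance below therefore steps by
# BLOCK_SIZE_K // 2 bytes per iteration.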
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsn + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32)
if has_zp and use_int4_w4a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
(offs_bn[None, :] // 2) * stride_bzn + \
offs_k_true * stride_bzk
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = ((b_zp >> b_zp_shifter) & 0xF)
b_zp = b_zp.to(tl.float32)
elif has_zp and use_int8_w8a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
offs_bn[None, :] * stride_bzn + \
offs_k_true * stride_bzk
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = b_zp.to(tl.float32)
# We accumulate along the K dimension.
if has_zp:
b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
else:
b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
if use_int4_w4a16:
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
else:
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
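For orientation, the sketch below (mine, not part of the commit) spells out in plain PyTorch the per-expert math this Triton kernel fuses on the int4 path: unpack two 4-bit weights per byte along K, subtract the implicit or per-group zero point, apply the group-wise scale, then multiply with the routed activations. Function name, shapes and layouts are illustrative only.

```python
from typing import Optional

import torch


def dequant_int4(qweight: torch.Tensor,          # (K // 2, N) uint8, two nibbles per byte along K
                 scales: torch.Tensor,           # (K // group_size, N)
                 zeros: Optional[torch.Tensor],  # (K // group_size, N) or None (symmetric)
                 group_size: int) -> torch.Tensor:
    low = qweight & 0xF                          # even k -> low nibble
    high = (qweight >> 4) & 0xF                  # odd k  -> high nibble
    w_q = torch.stack((low, high), dim=1).reshape(-1, qweight.shape[1]).float()  # (K, N)
    groups = torch.arange(w_q.shape[0]) // group_size
    zp = zeros[groups].float() if zeros is not None else 8.0   # implicit int4 zero point
    return (w_q - zp) * scales[groups].float()


# activations of the tokens already routed to one expert
a = torch.randn(4, 128)                                        # (M, K)
qw = torch.randint(0, 256, (64, 256), dtype=torch.uint8)       # (K // 2, N) packed weights
sc = torch.rand(128 // 32, 256)                                # group_size = 32
out = a @ dequant_int4(qw, sc, None, group_size=32)            # (M, N)
```

This mirrors the kernel's `(b - b_zp) * b_scale` and `(b - b_zp_num) * b_scale` paths, minus the expert routing, block padding and fp32 accumulation that the Triton version handles.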
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
@@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
@@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
compute_type: tl.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
EM = sorted_token_ids.shape[0]
if A.shape[0] < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that the topk_ids of each token are unique, so
# num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.shape[0],
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
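# e.g. with batch_size = 4, top_k = 2 and BLOCK_SIZE_M = 64, at most
# 4 * 2 = 8 experts receive tokens, so EM is capped at 4 * 2 * 64 = 512
# rows (8 blocks) even if sorted_token_ids was padded to a much larger
# length.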
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
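# Group-wise weight-only quantization (GPTQ/AWQ style): here block_shape[1]
# carries the quantization group size along K, so dispatch to the dedicated
# int4/int8 dequantizing kernel instead of the generic fused_moe_kernel.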
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
fused_moe_kernel_gptq_awq[grid](
A,
B,
C,
B_scale,
B_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
B_scale.stride(0),
B_scale.stride(2),
B_scale.stride(1),
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
**config,
)
else:
fused_moe_kernel[grid](
A,
B,
C,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
A_scale.stride(0)
if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1)
if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0)
if B_scale is not None and B_scale.ndim >= 2 else 0,
B_scale.stride(2)
if B_scale is not None and B_scale.ndim == 3 else 0,
B_scale.stride(1)
if B_scale is not None and B_scale.ndim >= 2 else 0,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
**config,
)
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@@ -432,7 +692,7 @@ def try_get_optimal_moe_config(
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if block_shape is not None and block_shape[0] != 0:
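# block_shape[0] == 0 marks GPTQ/AWQ-style weight-only group quantization
# (block_shape[1] is just the group size along K), which must not force the
# tile sizes below; only true block-wise quantization (e.g. fp8 W8A8) does.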
config["BLOCK_SIZE_N"] = block_shape[0]
config["BLOCK_SIZE_K"] = block_shape[1]
return config
@@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor,
def get_config_dtype_str(dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w8a16"
elif dtype == torch.float:
# avoid cases where the kernel fails when a float32 MoE
# uses fp16/bfloat16 configs
@@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
def inplace_fused_experts_fake(
@@ -569,8 +835,11 @@ def inplace_fused_experts_fake(
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
@@ -593,14 +862,18 @@ def outplace_fused_experts(
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
def outplace_fused_experts_fake(
@@ -611,8 +884,11 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
@@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor,
inplace: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
@@ -644,16 +923,15 @@ def fused_experts(hidden_states: torch.Tensor,
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
def fused_experts_impl(hidden_states: torch.Tensor,
@@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
get_config_func = functools.partial(
@@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1,
a1_scale,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
@@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)
torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache3,
a2_scale,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
@@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
@@ -808,8 +1099,11 @@ def fused_moe(
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
@@ -834,8 +1128,12 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -873,8 +1171,11 @@ def fused_moe(
inplace=inplace,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
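A rough, hypothetical usage sketch (mine, not from the commit) of driving the new int4 path through `fused_experts`. The tensor layouts follow my reading of the kernel and asserts above (weights packed two-per-byte along K, group-wise scales, `block_shape = [0, group_size]`); the actual GPTQ/AWQ MoE integration in vLLM may repack or shape the weights differently.

```python
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

E, K, I, group_size, top_k = 4, 1024, 2048, 128, 2          # experts, hidden, intermediate
x = torch.randn(16, K, dtype=torch.half, device="cuda")      # 16 tokens
# int4 weights packed two-per-byte along the K dimension
w1 = torch.randint(0, 256, (E, 2 * I, K // 2), dtype=torch.uint8, device="cuda")
w2 = torch.randint(0, 256, (E, K, I // 2), dtype=torch.uint8, device="cuda")
w1_scale = torch.rand(E, 2 * I, K // group_size, dtype=torch.half, device="cuda")
w2_scale = torch.rand(E, K, I // group_size, dtype=torch.half, device="cuda")
topk_weights = torch.full((16, top_k), 0.5, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(0, E, (16, top_k), dtype=torch.int32, device="cuda")

out = fused_experts(x, w1, w2, topk_weights, topk_ids,
                    use_int4_w4a16=True,
                    w1_scale=w1_scale, w2_scale=w2_scale,
                    w1_zp=None, w2_zp=None,        # symmetric: implicit zero point of 8
                    block_shape=[0, group_size])   # block_shape[1] carries the group size
```

With `use_int4_w4a16=True` and a nonzero `block_shape[1]`, `invoke_fused_moe_kernel` routes both matmuls through `fused_moe_kernel_gptq_awq` above; passing `w1_zp`/`w2_zp` tensors instead of None would exercise the asymmetric (`has_zp`) path.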