[oss] Init gpt-oss bf16 support (#22508)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -275,6 +275,7 @@ def fused_moe_kernel(
         a_ptr,
         b_ptr,
         c_ptr,
+        b_bias_ptr,
         a_scale_ptr,
         b_scale_ptr,
         topk_weights_ptr,
@@ -302,6 +303,8 @@ def fused_moe_kernel(
         stride_bse,
         stride_bsk,
         stride_bsn,
+        stride_bbe,  # bias expert stride
+        stride_bbn,  # bias N stride
         # Block size for block-wise quantization
         group_n: tl.constexpr,
         group_k: tl.constexpr,
@@ -317,6 +320,7 @@ def fused_moe_kernel(
         use_int8_w8a8: tl.constexpr,
         use_int8_w8a16: tl.constexpr,
         per_channel_quant: tl.constexpr,
+        HAS_BIAS: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -414,7 +418,10 @@ def fused_moe_kernel(
     else:
         a_scale = tl.load(a_scale_ptr)
         b_scale = tl.load(b_scale_ptr + off_experts)
-
+    if HAS_BIAS:
+        # bias shape: [num_experts, N]
+        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
+        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
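The two new strides (stride_bbe, stride_bbn) make the bias load layout-agnostic: for the expert owned by this program, the kernel gathers one row of the [num_experts, N] bias tensor and later broadcasts it over the M dimension of the output tile. A minimal sketch of the same index arithmetic in plain PyTorch (hypothetical sizes, not vLLM code):

import torch

num_experts, N = 8, 64
b_bias = torch.randn(num_experts, N)   # bias shape: [num_experts, N]
off_experts = 3                        # expert id owned by this program
offs_bn = torch.arange(16)             # column offsets of the output tile

# stride_bbe corresponds to b_bias.stride(0), stride_bbn to b_bias.stride(1)
flat = b_bias.reshape(-1)
bias = flat[off_experts * b_bias.stride(0) + offs_bn * b_bias.stride(1)]
assert torch.equal(bias, b_bias[off_experts, :16])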
@@ -456,7 +463,8 @@ def fused_moe_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
-
+    if HAS_BIAS:
+        accumulator = accumulator + bias[None, :]
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token,
                              mask=token_mask,
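Ordering note: the bias is added to the accumulator before the routed weight is applied, so each token contributes weight * (x @ W_e^T + b_e) rather than weight * (x @ W_e^T) + b_e. A reference sketch of that semantics (plain PyTorch, hypothetical shapes):

import torch

x = torch.randn(4, 16)       # 4 tokens routed to the same expert
W_e = torch.randn(32, 16)    # one expert's weight matrix
b_e = torch.randn(32)        # that expert's bias row
w = torch.rand(4, 1)         # per-token routed weights

y = w * (x @ W_e.t() + b_e)  # bias inside the product, routed weight outside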
@@ -471,6 +479,7 @@ def fused_moe_kernel(
             accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
+
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -499,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
                             per_channel_quant: bool,
-                            block_shape: Optional[list[int]] = None) -> None:
+                            block_shape: Optional[list[int]] = None,
+                            B_bias: Optional[torch.Tensor] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -531,7 +541,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                  A.size(0) * top_k * config['BLOCK_SIZE_M'])
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
         B.size(1), META['BLOCK_SIZE_N']), )
-
+    HAS_BIAS = B_bias is not None
     if (use_int8_w8a16 or use_int4_w4a16) and \
             block_shape is not None and block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
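HAS_BIAS is computed on the host and handed to the kernel as a tl.constexpr (see the launch further down), so Triton compiles a separate specialization per value and the `if HAS_BIAS` branches cost nothing when no bias is given. A standalone sketch of that pattern (hypothetical toy kernel, not the vLLM one):

import torch
import triton
import triton.language as tl

@triton.jit
def add_bias_kernel(x_ptr, bias_ptr, out_ptr, n_elements,
                    HAS_BIAS: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if HAS_BIAS:  # resolved at compile time; the dead branch is eliminated
        x += tl.load(bias_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
bias = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256), )
add_bias_kernel[grid](x, bias, out, x.numel(),
                      HAS_BIAS=bias is not None, BLOCK=256)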
@@ -611,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         A,
         B,
         C,
+        B_bias,
         A_scale,
         B_scale,
         topk_weights,
@@ -638,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         if B_scale is not None and B_scale.ndim == 3 else 0,
         B_scale.stride(1)
         if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_bias.stride(0) if B_bias is not None else 0,
+        B_bias.stride(1) if B_bias is not None else 0,
         0 if block_shape is None else block_shape[0],
         0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
@@ -647,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         per_channel_quant=per_channel_quant,
+        HAS_BIAS=HAS_BIAS,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         **config,
     )
@@ -1024,40 +1038,43 @@ def inplace_fused_experts(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None) -> None:  #noqa: UP006
+        block_shape: Optional[List[int]] = None,  #noqa: UP006
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, is_act_and_mul,
                        apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
-                       a2_scale, block_shape)
+                       a2_scale, block_shape, w1_bias, w2_bias)
 
 
-def inplace_fused_experts_fake(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str = "silu",
-        is_act_and_mul: bool = True,
-        apply_router_weight_on_input: bool = False,
-        use_fp8_w8a8: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int8_w8a16: bool = False,
-        use_int4_w4a16: bool = False,
-        use_mxfp4_w4a4: bool = False,
-        per_channel_quant: bool = False,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        w1_zp: Optional[torch.Tensor] = None,
-        w2_zp: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> None:
+def inplace_fused_experts_fake(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w2: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               activation: str = "silu",
+                               is_act_and_mul: bool = True,
+                               apply_router_weight_on_input: bool = False,
+                               use_fp8_w8a8: bool = False,
+                               use_int8_w8a8: bool = False,
+                               use_int8_w8a16: bool = False,
+                               use_int4_w4a16: bool = False,
+                               use_mxfp4_w4a4: bool = False,
+                               per_channel_quant: bool = False,
+                               global_num_experts: int = -1,
+                               expert_map: Optional[torch.Tensor] = None,
+                               w1_scale: Optional[torch.Tensor] = None,
+                               w2_scale: Optional[torch.Tensor] = None,
+                               w1_zp: Optional[torch.Tensor] = None,
+                               w2_zp: Optional[torch.Tensor] = None,
+                               a1_scale: Optional[torch.Tensor] = None,
+                               a2_scale: Optional[torch.Tensor] = None,
+                               block_shape: Optional[list[int]] = None,
+                               w1_bias: Optional[torch.Tensor] = None,
+                               w2_bias: Optional[torch.Tensor] = None) -> None:
     pass
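The `_fake` variant exists because inplace_fused_experts is registered as a PyTorch custom op (see direct_register_custom_op below), and torch.compile needs a shape-only implementation it can trace without launching kernels; for an in-place op that returns None, there is nothing to compute. A sketch of the same pattern using plain torch.library (hypothetical op, PyTorch >= 2.4, not vLLM's registration helper):

import torch

@torch.library.custom_op("demo::scale_inplace", mutates_args=("x", ))
def scale_inplace(x: torch.Tensor, s: float) -> None:
    x.mul_(s)

@scale_inplace.register_fake
def _(x: torch.Tensor, s: float) -> None:
    pass  # mirrors inplace_fused_experts_fake's bare `pass`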
@@ -1246,36 +1263,38 @@ direct_register_custom_op(
 
 
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
         is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         use_mxfp4_w4a4: bool = False,
         per_channel_quant: bool = False,
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None,  #noqa: UP006
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
         is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
         use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
         per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
-        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
 
 
 def outplace_fused_experts_fake(
@@ -1300,7 +1319,9 @@ def outplace_fused_experts_fake(
         w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
@@ -1332,33 +1353,34 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
 
 # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
 # torch ops.
-def fused_experts(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        inplace: bool = False,
-        activation: str = "silu",
-        is_act_and_mul: bool = True,
-        apply_router_weight_on_input: bool = False,
-        use_fp8_w8a8: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int8_w8a16: bool = False,
-        use_int4_w4a16: bool = False,
-        use_mxfp4_w4a4: bool = False,
-        per_channel_quant: bool = False,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        w1_zp: Optional[torch.Tensor] = None,
-        w2_zp: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None,
-        allow_deep_gemm: bool = False,
-        allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  activation: str = "silu",
+                  is_act_and_mul: bool = True,
+                  apply_router_weight_on_input: bool = False,
+                  use_fp8_w8a8: bool = False,
+                  use_int8_w8a8: bool = False,
+                  use_int8_w8a16: bool = False,
+                  use_int4_w4a16: bool = False,
+                  use_mxfp4_w4a4: bool = False,
+                  per_channel_quant: bool = False,
+                  global_num_experts: int = -1,
+                  expert_map: Optional[torch.Tensor] = None,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  w1_zp: Optional[torch.Tensor] = None,
+                  w2_zp: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None,
+                  block_shape: Optional[list[int]] = None,
+                  allow_deep_gemm: bool = False,
+                  allow_cutlass_block_scaled_grouped_gemm: bool = False,
+                  w1_bias: Optional[torch.Tensor] = None,
+                  w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
     # However, on B200, we use DeepGemm for all cases because they only support
@@ -1423,7 +1445,10 @@ def fused_experts(
         w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=block_shape)
+        block_shape=block_shape,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+    )
 
 
 def fused_experts_impl(
@@ -1451,6 +1476,8 @@ def fused_experts_impl(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
@@ -1591,7 +1618,19 @@ def fused_experts_impl(
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
                                 per_channel_quant=per_channel_quant,
-                                block_shape=block_shape)
+                                block_shape=block_shape,
+                                B_bias=w1_bias)
 
+        # TODO fused kernel
+        def swiglu_oai(gate_up):
+            alpha = 1.702
+            limit = 7.0
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=limit)
+            up = up.clamp(min=-limit, max=limit)
+            glu = gate * torch.sigmoid(gate * alpha)
+            gated_output = (up + 1) * glu
+            return gated_output
+
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
@@ -1605,6 +1644,8 @@ def fused_experts_impl(
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
+        elif activation == "swiglu_oai":
+            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")
@@ -1635,7 +1676,8 @@ def fused_experts_impl(
                                 use_int8_w8a16=use_int8_w8a16,
                                 use_int4_w4a16=use_int4_w4a16,
                                 per_channel_quant=per_channel_quant,
-                                block_shape=block_shape)
+                                block_shape=block_shape,
+                                B_bias=w2_bias)
 
         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
@@ -1672,6 +1714,8 @@ def fused_moe(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1766,7 +1810,9 @@ def fused_moe(
         w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
-        block_shape=block_shape)
+        block_shape=block_shape,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias)
 
 
 class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -1937,7 +1983,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                 use_int8_w8a16=self.use_int8_w8a16,
                                 use_int4_w4a16=self.use_int4_w4a16,
                                 per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+                                block_shape=self.block_shape,
+                                B_bias=None  # TODO support B_bias
+        )
 
         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1948,26 +1996,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             intermediate_cache2, a2_scale, self.quant_dtype,
             self.per_act_token_quant, self.block_shape)
 
-        invoke_fused_moe_kernel(qintermediate_cache2,
-                                w2,
-                                intermediate_cache3,
-                                a2q_scale,
-                                w2_scale,
-                                w2_zp,
-                                topk_weights,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                not apply_router_weight_on_input,
-                                1,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            qintermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2q_scale,
+            w2_scale,
+            w2_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            not apply_router_weight_on_input,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_act_token_quant,
+            block_shape=self.block_shape,
+            B_bias=None  # TODO support B_bias
+        )
 
         ops.moe_sum(intermediate_cache3, output)
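End-to-end, the new bias path is reached through the fused_experts entry point shown above. A usage sketch (hypothetical shapes and import path, assuming a CUDA device and the interleaved swiglu_oai layout in which w1 produces 2*I channels per expert):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

E, H, I, T, topk = 8, 128, 256, 16, 2
x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.bfloat16, device="cuda") * 0.02
w2 = torch.randn(E, H, I, dtype=torch.bfloat16, device="cuda") * 0.02
w1_bias = torch.zeros(E, 2 * I, dtype=torch.bfloat16, device="cuda")
w2_bias = torch.zeros(E, H, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.softmax(torch.rand(T, topk, device="cuda"), dim=-1)
topk_ids = torch.randint(0, E, (T, topk), device="cuda", dtype=torch.int32)

out = fused_experts(x, w1, w2, topk_weights, topk_ids,
                    activation="swiglu_oai",
                    w1_bias=w1_bias, w2_bias=w2_bias)  # bf16 bias path
print(out.shape)  # expected: torch.Size([16, 128])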