[Model] [Quantization] Support deepseek_v3 w8a8 fp8 block-wise quantization (#11523)
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: HandH1998 <1335248067@qq.com>
@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 import triton
@@ -11,6 +11,8 @@ import triton.language as tl
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

@@ -45,8 +47,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
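The new kernel arguments plumb the block-wise scale tensors into the Triton kernel. An informal reference for what each added argument indexes (the descriptions are this note's interpretation, not text from the diff):

# Informal summary of the new fused_moe_kernel arguments.
NEW_KERNEL_ARGS = {
    "stride_asm": "a_scale stride along tokens (row = token index // top_k)",
    "stride_ask": "a_scale stride along K-groups",
    "stride_bsk": "b_scale stride along K-groups",
    "group_n": "weight block size along N, i.e. block_shape[0]",
    "group_k": "weight/activation block size along K, i.e. block_shape[1]",
}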
@@ -125,8 +133,14 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)

     if use_fp8_w8a8:
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr + off_experts)
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
+                            offs_bsn * stride_bsn)
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
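When group_k and group_n are positive, the kernel prepares per-group scale pointers here and loads a fresh scale pair on every K iteration, instead of one scalar per tensor. A minimal host-side sketch of the index arithmetic (a hypothetical helper, not part of the change):

def scale_indices(token_id: int, top_k: int, n: int, k: int,
                  group_n: int, group_k: int):
    # Routed tokens are replicated top_k times, so all copies of a token
    # share one row of activation scales: a_scale[token_id // top_k, ...].
    a_idx = (token_id // top_k, k // group_k)
    # Weight scales are per (expert, n-group, k-group); the expert offset
    # is applied separately via off_experts * stride_bse in the kernel.
    b_idx = (n // group_n, k // group_k)
    return a_idx, b_idx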
@@ -149,7 +163,18 @@ def fused_moe_kernel(
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-            accumulator = tl.dot(a, b, acc=accumulator)
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
+                                  mask=token_mask,
+                                  other=0.0)
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:,
+                                                      None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
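Numerically, the group-wise path rescales each K-block partial product before accumulating, rather than applying a single tensor-wide scale at the end. A self-contained PyTorch reference of that accumulation (a sketch under assumed layouts: a_s is [M, K // group_k], b_s is [N // group_n, K // group_k]; this is not the Triton kernel itself):

import torch

def blockwise_matmul_ref(a_q: torch.Tensor, a_s: torch.Tensor,
                         b_q: torch.Tensor, b_s: torch.Tensor,
                         group_n: int, group_k: int) -> torch.Tensor:
    # a_q: [M, K] quantized activations, b_q: [K, N] quantized weights.
    M, K = a_q.shape
    N = b_q.shape[1]
    acc = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        ks = k0 // group_k
        # Partial product over one K-group, computed on the quantized values.
        partial = a_q[:, k0:k0 + group_k].float() @ b_q[k0:k0 + group_k].float()
        # Broadcast per-token and per-N-group scales over the tile, mirroring
        # tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] in the kernel.
        b_scale_n = b_s[:, ks].repeat_interleave(group_n)  # -> [N]
        acc += partial * a_s[:, ks, None] * b_scale_n[None, :]
    return acc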
@@ -164,7 +189,10 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
@@ -233,22 +261,37 @@ def moe_align_block_size(
     return sorted_ids, expert_ids, num_tokens_post_pad


-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+def invoke_fused_moe_kernel(A: torch.Tensor,
+                            B: torch.Tensor,
+                            C: torch.Tensor,
                             A_scale: Optional[torch.Tensor],
                             B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            topk_weights: torch.Tensor,
+                            topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
+                            mul_routed_weight: bool,
+                            top_k: int,
+                            config: Dict[str, Any],
+                            compute_type: tl.dtype,
+                            use_fp8_w8a8: bool,
+                            use_int8_w8a16: bool,
+                            block_shape: Optional[List[int]] = None) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

     if use_fp8_w8a8:
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
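On the activation side, block_shape switches invoke_fused_moe_kernel from one tensor-wide fp8 scale to one scale per (token, K-group), which is what the shape asserts above check. A rough standalone sketch of per-token-group quantization (the real helper lives in fp8_utils; exact numerics may differ):

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_k: int):
    # One scale per (token, K-group); mirrors the asserted shape
    # A_scale.shape[-1] == cdiv(A.shape[-1], group_k) when K divides evenly.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    M, K = x.shape
    assert K % group_k == 0
    xg = x.float().view(M, K // group_k, group_k)
    scales = xg.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max
    x_q = (xg / scales[..., None]).view(M, K).to(torch.float8_e4m3fn)
    return x_q, scales  # x_q: [M, K] fp8, scales: [M, K // group_k] fp32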
@@ -279,8 +322,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         B.stride(1),
         C.stride(1),
         C.stride(2),
-        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
-        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
@@ -362,6 +410,7 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
@@ -380,6 +429,12 @@ def try_get_optimal_moe_config(
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                     is_marlin)
+        # NOTE: For block-wise quant,
+        # BLOCK_K must be divisible by block_shape[1]
+        # BLOCK_N and BLOCK_M have no requirements
+        if block_shape is not None:
+            config["BLOCK_SIZE_N"] = block_shape[0]
+            config["BLOCK_SIZE_K"] = block_shape[1]
     return config


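The override pins the kernel's N and K tile sizes to the scale grid, so every tile maps to exactly one b_scale entry. A hypothetical illustration with DeepSeek-V3-style 128x128 blocks (the default config values below are made up; real ones come from get_default_config):

config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64,
          "GROUP_SIZE_M": 8}  # hypothetical tuned defaults
block_shape = [128, 128]      # [block_n, block_k], as stored in the checkpoint
if block_shape is not None:
    config["BLOCK_SIZE_N"] = block_shape[0]
    config["BLOCK_SIZE_K"] = block_shape[1]
# Now offs_ks = k_start // group_k is constant within each K tile, so the
# kernel loads a single scale pair per inner-loop iteration.
print(config)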
@@ -479,10 +534,11 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          a1_scale: Optional[torch.Tensor] = None,
-                         a2_scale: Optional[torch.Tensor] = None) -> None:
+                         a2_scale: Optional[torch.Tensor] = None,
+                         block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
-                       a1_scale, a2_scale)
+                       a1_scale, a2_scale, block_shape)


 def inplace_fused_experts_fake(
@@ -496,7 +552,8 @@ def inplace_fused_experts_fake(
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None) -> None:
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None) -> None:
     pass


@@ -519,10 +576,11 @@ def outplace_fused_experts(
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
-                              w2_scale, a1_scale, a2_scale)
+                              w2_scale, a1_scale, a2_scale, block_shape)


 def outplace_fused_experts_fake(
@@ -536,7 +594,8 @@ def outplace_fused_experts_fake(
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return torch.empty_like(hidden_states)


@@ -559,18 +618,22 @@ def fused_experts(hidden_states: torch.Tensor,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
-                 a2_scale: Optional[torch.Tensor] = None):
+                 a2_scale: Optional[torch.Tensor] = None,
+                 block_shape: Optional[List[int]] = None):
     if inplace:
         torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
                                              topk_weights, topk_ids,
                                              use_fp8_w8a8, use_int8_w8a16,
                                              w1_scale, w2_scale, a1_scale,
-                                             a2_scale)
+                                             a2_scale, block_shape)
         return hidden_states
     else:
-        return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale)
+        return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2,
+                                                     topk_weights, topk_ids,
+                                                     use_fp8_w8a8,
+                                                     use_int8_w8a16, w1_scale,
+                                                     w2_scale, a1_scale,
+                                                     a2_scale, block_shape)


 def fused_experts_impl(hidden_states: torch.Tensor,
@@ -584,7 +647,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                       w1_scale: Optional[torch.Tensor] = None,
                       w2_scale: Optional[torch.Tensor] = None,
                       a1_scale: Optional[torch.Tensor] = None,
-                      a2_scale: Optional[torch.Tensor] = None):
+                      a2_scale: Optional[torch.Tensor] = None,
+                      block_shape: Optional[List[int]] = None):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
@@ -611,6 +675,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         w2.shape,
         topk_ids.shape[1],
         config_dtype,
+        block_shape=block_shape,
     )

     config = get_config_func(M)
@@ -674,7 +739,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=use_fp8_w8a8,
-                               use_int8_w8a16=use_int8_w8a16)
+                               use_int8_w8a16=use_int8_w8a16,
+                               block_shape=block_shape)

         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

@@ -693,7 +759,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                config,
                                compute_type=compute_type,
                                use_fp8_w8a8=use_fp8_w8a8,
-                               use_int8_w8a16=use_int8_w8a16)
+                               use_int8_w8a16=use_int8_w8a16,
+                               block_shape=block_shape)

         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
@@ -718,6 +785,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -745,6 +813,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
        a1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -775,4 +849,5 @@ def fused_moe(
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
-                        a2_scale=a2_scale)
+                        a2_scale=a2_scale,
+                        block_shape=block_shape)

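Putting it together, a hedged end-to-end sketch of calling the extended fused_moe API (hypothetical shapes; assumes vLLM on an fp8-capable CUDA GPU, with random fp8 data used purely to exercise the shapes, not to produce meaningful output):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, M, K, N, top_k = 8, 4, 512, 1024, 2   # experts, tokens, hidden, intermediate
block_n = block_k = 128                  # DeepSeek-V3-style weight blocks
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)
w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)
# One fp32 scale per 128x128 weight tile, per expert.
w1_scale = torch.rand(E, 2 * N // block_n, K // block_k, device="cuda")
w2_scale = torch.rand(E, K // block_n, N // block_k, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, top_k,
                renormalize=True, use_fp8_w8a8=True,
                w1_scale=w1_scale, w2_scale=w2_scale,
                block_shape=[block_n, block_k])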