[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                     FLOAT8_E4M3_MAX)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -34,18 +35,22 @@ def triton_moe(
     per_act_token_quant=False,
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
+    quant_config = FusedMoEQuantConfig.make(
+        quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
+
     return fused_experts(a,
                          w1,
                          w2,
                          topk_weight,
                          topk_ids,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale,
-                         per_channel_quant=per_act_token_quant,
-                         use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
-                         block_shape=block_shape)
+                         quant_config=quant_config)


 def batched_moe(
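Note: the hunk above is the heart of this change. Scale tensors and quantization flags that used to be passed piecemeal to fused_experts (use_fp8_w8a8, per_channel_quant, block_shape, the *_scale tensors) are now bundled once into a FusedMoEQuantConfig. A minimal usage sketch based only on the signatures visible in this diff; the tensor shapes and the unquantized (quant_dtype=None) path are illustrative assumptions, not code from the commit:

    import torch
    from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    # Illustrative sizes (assumptions): m tokens, hidden size k, e experts,
    # intermediate size n, top-2 routing.
    m, k, e, n, topk = 16, 256, 8, 512, 2
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 15
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 15
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # Assumption: passing quant_dtype=None yields the unquantized config, so
    # no scale tensors are required (mirrors the defaults in this diff).
    quant_config = FusedMoEQuantConfig.make(None)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
    out = fused_experts(a, w1, w2, topk_weights, topk_ids,
                        quant_config=quant_config)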
@@ -64,6 +69,16 @@ def batched_moe(
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)

+    quant_config = FusedMoEQuantConfig.make(
+        quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
+
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
                                   num_dispatchers=1,
@@ -72,21 +87,11 @@ def batched_moe(
         BatchedTritonExperts(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
-            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
-            per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape,
+            quant_config=quant_config,
         ),
     )

-    return fused_experts(a,
-                         w1,
-                         w2,
-                         topk_weight,
-                         topk_ids,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
+    return fused_experts(a, w1, w2, topk_weight, topk_ids)


 def naive_batched_moe(
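Note: in the modular-kernel path the config is built once and handed to the experts implementation instead of the old constructor flags. For reference, a hedged sketch of an fp8 config equivalent to the removed use_fp8_w8a8 branch; the per-expert scalar scale shapes are assumptions for a per-tensor layout:

    import torch
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    e = 8  # number of experts (assumption)
    # One scalar scale per expert and weight matrix (assumed per-tensor layout).
    w1_scale = torch.ones((e, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.ones((e, 1, 1), device="cuda", dtype=torch.float32)

    quant_config = FusedMoEQuantConfig.make(
        torch.float8_e4m3fn,        # replaces use_fp8_w8a8=True
        per_act_token_quant=False,  # per-tensor activation quantization
        block_shape=None,           # no block-wise quantization
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )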
@@ -105,6 +110,16 @@ def naive_batched_moe(
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)

+    quant_config = FusedMoEQuantConfig.make(
+        quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
+
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
                                   num_dispatchers=1,
@@ -113,21 +128,11 @@ def naive_batched_moe(
         NaiveBatchedExperts(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
-            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
-            per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape,
+            quant_config=quant_config,
         ),
     )

-    return fused_experts(a,
-                         w1,
-                         w2,
-                         topk_weight,
-                         topk_ids,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
+    return fused_experts(a, w1, w2, topk_weight, topk_ids)


 def chunk_scales(scales: Optional[torch.Tensor], start: int,
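Note: NaiveBatchedExperts mirrors BatchedTritonExperts above; after this change both receive the config object rather than individual flags. A sketch of the new constructor call, using only the arguments visible in this diff (the unquantized config default is an assumption):

    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        NaiveBatchedExperts)

    experts = NaiveBatchedExperts(
        max_num_tokens=64,
        num_dispatchers=1,
        quant_config=FusedMoEQuantConfig.make(None),  # assumed no-quant default
    )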
@@ -216,7 +221,7 @@ def make_test_weight(
     in_dtype: torch.dtype = torch.bfloat16,
     quant_dtype: Union[torch.dtype, str, None] = None,
     block_shape: Optional[list[int]] = None,
-    per_act_token_quant: bool = False,
+    per_out_ch_quant: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
            Optional[torch.Tensor]]:
     w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
@@ -228,7 +233,7 @@ def make_test_weight(
     w_gs_l = [None] * e
     for idx in range(e):
         w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
-            w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
+            w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)

     w = torch.stack(w_l)
     w_s = torch.stack(w_s_l)
@@ -258,16 +263,16 @@ def make_test_weights(
     in_dtype: torch.dtype = torch.bfloat16,
     quant_dtype: Union[torch.dtype, str, None] = None,
     block_shape: Optional[list[int]] = None,
-    per_act_token_quant: bool = False,
+    per_out_ch_quant: bool = False,
 ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                  Optional[torch.Tensor]],
            tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                  Optional[torch.Tensor]]]:
     return (
         make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
-                         per_act_token_quant),
+                         per_out_ch_quant),
         make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
-                         per_act_token_quant),
+                         per_out_ch_quant),
     )

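Note: the two hunks above rename the weight-path parameter from per_act_token_quant to per_out_ch_quant, since weight scales are produced per output channel rather than per token. A hedged usage sketch; the sizes and the module import path are assumptions:

    import torch
    # Assumption: this file is importable as tests.kernels.moe.utils.
    from tests.kernels.moe.utils import make_test_weights

    # Returns ((w_16, w, w_scale, w_gscale)) per weight matrix, matching the
    # tuple unpacking used by make_test_quant_config below.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        8, 512, 256,  # e, n, k (assumed sizes)
        in_dtype=torch.bfloat16,
        quant_dtype=torch.float8_e4m3fn,
        block_shape=None,
        per_out_ch_quant=True,  # fp8 scales per output channel
    )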
@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
     return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


+def make_test_quant_config(
+    e: int,
+    n: int,
+    k: int,
+    in_dtype: torch.dtype,
+    quant_dtype: Union[torch.dtype, str, None] = None,
+    per_act_token_quant: bool = False,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
+    (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
+        e,
+        n,
+        k,
+        in_dtype,
+        quant_dtype,
+        per_out_ch_quant=per_act_token_quant,
+        block_shape=block_shape,
+    )
+
+    # Hacky/trivial scales for nvfp4.
+    a1_gscale: Optional[torch.Tensor] = None
+    a2_gscale: Optional[torch.Tensor] = None
+    if quant_dtype == "nvfp4":
+        a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
+        a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
+        a1_scale = a1_gscale
+        a2_scale = a2_gscale
+    else:
+        a1_scale = None
+        a2_scale = None
+
+    return w1, w2, FusedMoEQuantConfig.make(
+        quant_dtype,
+        per_act_token_quant=per_act_token_quant,
+        block_shape=block_shape,
+        w1_scale=w1_s,
+        w2_scale=w2_s,
+        a1_gscale=a1_gscale,
+        a2_gscale=a2_gscale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        # TODO: make sure this is handled properly
+        g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
+        g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
+    )
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+    renormalize: bool = False,
+    quant_config: Optional[FusedMoEQuantConfig] = None,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
+                                           renormalize)
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         global_num_experts=global_num_experts,
+                         expert_map=expert_map,
+                         quant_config=quant_config)
+
+
 # CustomOp?
 class BaselineMM(torch.nn.Module):
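Note: taken together, the two new helpers give tests a one-stop setup. make_test_quant_config produces quantized weights plus a matching config, and the fused_moe wrapper threads that config through fused_topk and fused_experts. A hedged end-to-end sketch; the sizes, fp8 block-quant choice, and module import path are assumptions:

    import torch
    # Assumption: this file is importable as tests.kernels.moe.utils.
    from tests.kernels.moe.utils import fused_moe, make_test_quant_config

    e, n, k, m = 8, 512, 256, 16
    w1, w2, quant_config = make_test_quant_config(
        e, n, k,
        in_dtype=torch.bfloat16,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=[128, 128],  # assumed block-quant shape
    )

    hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
    out = fused_moe(hidden_states, w1, w2, score, topk=2,
                    quant_config=quant_config)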