[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

commit 5963b98b46 (parent e6585ddb45)
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
Signed-off-by: Bill Nell <bnell@redhat.com>

68 changed files with 2698 additions and 2526 deletions


@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
-from .mk_objects import (expert_info, make_fused_experts,
+from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
                          make_prepare_finalize, prepare_finalize_info)
 from .parallel_utils import ProcessGroupInfo
@@ -40,7 +40,7 @@ class Config:
     E: int
     topks: Union[list[int], int]
     dtype: torch.dtype
-    quant_config: Optional[FusedMoEQuantConfig]
+    quant_config: Optional[TestMoEQuantConfig]
     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
     fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
@@ -52,7 +52,7 @@ class Config:
     def __post_init__(self):
         if self.quant_config is None:
-            self.quant_config = FusedMoEQuantConfig()
+            self.quant_config = TestMoEQuantConfig(None, False, False, None)

     def describe(self) -> str:
         s = ""
@@ -275,21 +275,19 @@ class WeightTensors:
                 or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)

     def to_current_device(self):
-        self.w1 = self.w1.to(device=torch.cuda.current_device())
-        self.w2 = self.w2.to(device=torch.cuda.current_device())
+        device = torch.cuda.current_device()
+        self.w1 = self.w1.to(device=device)
+        self.w2 = self.w2.to(device=device)

-        if self.is_quantized():
-            assert self.w1_scale is not None
-            assert self.w2_scale is not None
-            self.w1_scale = self.w1_scale.to(
-                device=torch.cuda.current_device())
-            self.w2_scale = self.w2_scale.to(
-                device=torch.cuda.current_device())
+        if self.w1_scale is not None:
+            self.w1_scale = self.w1_scale.to(device=device)
+        if self.w2_scale is not None:
+            self.w2_scale = self.w2_scale.to(device=device)

         if self.w1_gs is not None:
-            assert self.w2_gs is not None
-            self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
-            self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
+            self.w1_gs = self.w1_gs.to(device=device)
+        if self.w2_gs is not None:
+            self.w2_gs = self.w2_gs.to(device=device)

     def slice_weights(self, rank: int,
                       num_local_experts: int) -> "WeightTensors":
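The hunk above hoists the repeated torch.cuda.current_device() lookup into a local and replaces the is_quantized() gate (plus asserts) with an independent None check per optional tensor. A minimal standalone sketch of the pattern, using a hypothetical class rather than the real WeightTensors, and assuming a CUDA device is available:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _Weights:  # hypothetical stand-in for WeightTensors
    w1: torch.Tensor
    w1_scale: Optional[torch.Tensor] = None

    def to_current_device(self) -> None:
        # Look up the device once instead of once per .to() call.
        device = torch.cuda.current_device()
        self.w1 = self.w1.to(device=device)
        # Each optional tensor moves only if it exists; no asserts needed.
        if self.w1_scale is not None:
            self.w1_scale = self.w1_scale.to(device=device)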
@@ -297,20 +295,12 @@ class WeightTensors:
         e = s + num_local_experts
         w1 = self.w1[s:e, :, :]
         w2 = self.w2[s:e, :, :]
-        w1_scale, w2_scale = (None, None)
-        if self.is_quantized():
-            assert self.w1_scale is not None
-            assert self.w2_scale is not None
-            w1_scale = self.w1_scale[s:e, :, :]
-            w2_scale = self.w2_scale[s:e, :, :]
-        w1_gs = self.w1_gs
-        w2_gs = self.w2_gs
-        if w1_gs is not None:
-            assert w2_gs is not None
-            w1_gs = w1_gs[s:e]
-            w2_gs = w2_gs[s:e]
+        w1_scale = self.w1_scale[
+            s:e, :, :] if self.w1_scale is not None else None
+        w2_scale = self.w2_scale[
+            s:e, :, :] if self.w2_scale is not None else None
+        w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
+        w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
         return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
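The rewrite collapses the quantized and unquantized branches into conditional expressions: each rank keeps experts [s:e) along dim 0, and a tensor that is None simply passes through as None. An illustrative toy example of the slicing (shapes and values are not from the diff):

import torch
from typing import Optional

w1 = torch.randn(8, 16, 32)              # (num_experts, ...)
w1_scale: Optional[torch.Tensor] = None  # absent for unquantized runs

rank, num_local_experts = 1, 2
s = rank * num_local_experts             # 2
e = s + num_local_experts                # 4

w1_local = w1[s:e, :, :]                 # experts 2 and 3 go to rank 1
w1_scale_local = w1_scale[s:e, :, :] if w1_scale is not None else None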
@@ -323,7 +313,8 @@ class WeightTensors:
             in_dtype=config.dtype,
             quant_dtype=config.quant_dtype,
             block_shape=config.quant_block_shape,
-            per_act_token_quant=config.is_per_out_ch_quant,
+            per_out_ch_quant=config.
+            is_per_act_token_quant,  # or config.is_per_out_ch_quant
         )
         return WeightTensors(w1=w1,
                              w2=w2,
@@ -342,8 +333,6 @@ class RankTensors:
     topk_ids: torch.Tensor
     expert_map: Optional[torch.Tensor]
-    quant_config: Optional[FusedMoEQuantConfig]

     def describe(self):
         s = ""
         s += "== Rank Tensors: \n"
@@ -426,7 +415,6 @@ class RankTensors:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             expert_map=expert_map,
-            quant_config=config.quant_config,
         )
@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
             and config.supports_apply_weight_on_input())


+def _make_gscale(num_experts: int) -> torch.Tensor:
+    return torch.ones((num_experts, ),
+                      device=torch.cuda.current_device(),
+                      dtype=torch.float32)
+
+
 def make_modular_kernel(
     config: Config,
     vllm_config: VllmConfig,
-    weights: WeightTensors,
+    quant_config: FusedMoEQuantConfig,
 ) -> mk.FusedMoEModularKernel:

     def next_power_of_2(x):
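next_power_of_2 is defined locally in the test; its body falls outside this hunk. A common implementation of the helper, for reference (this body is an assumption, not taken from the file):

def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x; e.g. 5 -> 8, 8 -> 8, and 0 or 1 -> 1.
    return 1 << max(0, (x - 1).bit_length())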
@@ -548,20 +542,20 @@ def make_modular_kernel(
         num_local_experts=config.num_local_experts,
         moe_parallel_config=moe_parallel_config,
         in_dtype=config.dtype,
-        quant_config=config.quant_config,
         max_num_tokens=next_power_of_2(config.M),
     )

     # make modular kernel
     prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
-                                             config.all2all_backend(), moe)
+                                             config.all2all_backend(), moe,
+                                             quant_config)

     fused_experts = make_fused_experts(
         config.fused_experts_type,
         moe,
+        quant_config,
         prepare_finalize.num_dispatchers(),
-        weights.w1_gs,
-        weights.w2_gs,
+        config.N,
     )

     modular_kernel = mk.FusedMoEModularKernel(
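Taken together, these hunks change make_modular_kernel from reading quantization state off config and weights to receiving a prebuilt FusedMoEQuantConfig and threading it into both kernel halves. Condensed, the new flow looks roughly like this (argument lists abbreviated; helper signatures inferred from the call sites above):

def make_modular_kernel(config, vllm_config, quant_config):
    moe = ...  # FusedMoEConfig; no longer carries quant_config
    prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
                                             config.all2all_backend(), moe,
                                             quant_config)
    fused_experts = make_fused_experts(config.fused_experts_type, moe,
                                       quant_config,
                                       prepare_finalize.num_dispatchers(),
                                       config.N)
    return mk.FusedMoEModularKernel(prepare_finalize, fused_experts)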
@@ -583,12 +577,38 @@ def run_modular_kernel(
     # weights for rank
     rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

-    mk = make_modular_kernel(config, vllm_config, weights)
+    if config.quant_dtype == "nvfp4":
+        gscale = _make_gscale(config.num_local_experts)
+    else:
+        gscale = None
+
+    quant_config = FusedMoEQuantConfig.make(
+        config.quant_dtype,
+        w1_scale=rank_weights.w1_scale,
+        w2_scale=rank_weights.w2_scale,
+        a1_scale=rank_tensors.hidden_states_scale,
+        g1_alphas=(1 / rank_weights.w1_gs)
+        if rank_weights.w1_gs is not None else None,
+        g2_alphas=(1 / rank_weights.w2_gs)
+        if rank_weights.w2_gs is not None else None,
+        a1_gscale=gscale,
+        a2_gscale=gscale,
+        block_shape=config.quant_block_shape,
+        per_act_token_quant=config.is_per_act_token_quant,
+        per_out_ch_quant=config.is_per_out_ch_quant,
+    )
+
+    mk = make_modular_kernel(config, vllm_config, quant_config)
+
+    # impls might update the tensor in place
+    hidden_states = rank_tensors.hidden_states.clone()
+
+    topk_ids = rank_tensors.topk_ids.to(
+        mk.prepare_finalize.topk_indices_dtype())

     mk_kwargs = {
         "hidden_states":
-        rank_tensors.hidden_states.clone(
-        ),  # impls might update the tensor in place
+        hidden_states,
         "w1":
         rank_weights.w1,
         "w2":
@@ -596,15 +616,9 @@ def run_modular_kernel(
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
topk_ids,
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":