[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
@@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the topk weights to the
activations before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional quantized + dispatched a1_scales.
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
@@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
@@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
# TODO: add supported activations method (return string)
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
@@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: FusedMoEQuantConfig,
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
"""
quant_config: Quantization parameters for this experts instance.
"""
self.quant_config = quant_config
@property
@abstractmethod
@@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
#
# Various helpers for accessing quantization parameters from the
# quant_config.
#
@property
def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype
@@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
# Read-only accessor that forwards quant_config.a1_scale: the optional
# scales for the a1 activations (may be None, per the return annotation).
@property
def a1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_scale
# Read-only accessor that forwards quant_config.a2_scale: the optional
# scales for the second MoE gemm (may be None, per the return annotation).
@property
def a2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_scale
# Read-only accessor that forwards quant_config.a1_gscale (tensor or None).
# NOTE(review): presumably a global scale for a1 -- confirm against
# FusedMoEQuantConfig.
@property
def a1_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a1_gscale
# Read-only accessor that forwards quant_config.a2_gscale (tensor or None).
# NOTE(review): presumably a global scale for a2 -- confirm against
# FusedMoEQuantConfig.
@property
def a2_gscale(self) -> Optional[torch.Tensor]:
return self.quant_config.a2_gscale
# Read-only accessor that forwards quant_config.w1_scale: the optional
# scale to be used for w1 (may be None, per the return annotation).
@property
def w1_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_scale
# Read-only accessor that forwards quant_config.w2_scale: the optional
# scale to be used for w2 (may be None, per the return annotation).
@property
def w2_scale(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_scale
# Read-only accessor that forwards quant_config.w1_zp: the optional zero
# points to be used for w1 (may be None, per the return annotation).
@property
def w1_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_zp
# Read-only accessor that forwards quant_config.w2_zp: the optional zero
# points to be used for w2 (may be None, per the return annotation).
@property
def w2_zp(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_zp
# Read-only accessor that forwards quant_config.w1_bias (tensor or None).
# NOTE(review): presumably the bias paired with w1 -- confirm against
# FusedMoEQuantConfig.
@property
def w1_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w1_bias
# Read-only accessor that forwards quant_config.w2_bias (tensor or None).
# NOTE(review): presumably the bias paired with w2 -- confirm against
# FusedMoEQuantConfig.
@property
def w2_bias(self) -> Optional[torch.Tensor]:
return self.quant_config.w2_bias
# Read-only accessor that forwards quant_config.g1_alphas (tensor or None).
# NOTE(review): the meaning of g1_alphas is not visible from here -- see
# FusedMoEQuantConfig.
@property
def g1_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g1_alphas
# Read-only accessor that forwards quant_config.g2_alphas (tensor or None).
# NOTE(review): the meaning of g2_alphas is not visible from here -- see
# FusedMoEQuantConfig.
@property
def g2_alphas(self) -> Optional[torch.Tensor]:
return self.quant_config.g2_alphas
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
@@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
@@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
@@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
used for a1. Result of quantization from prepare/finalize and not
from the FusedMoEQuantConfig.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs;
must be large enough to hold the output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
@@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
@@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
@@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
@@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
return (
a1q[s:e],
_chunk_scales(a1q_scale, s, e),
_chunk_scales(self.fused_experts.a2_scale, s, e),
topk_ids[s:e],
topk_weights[s:e],
)
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, (
@@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
@@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
@@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
@@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
@@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)