[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
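In outline, the refactor stops threading raw scale/zero-point tensors through the modular-kernel call chain; instead each FusedMoEMethodBase subclass builds a FusedMoEQuantConfig up front, and the kernels read quantization parameters from it. A minimal sketch of the intended pattern, assuming a hypothetical fp8 method subclass and a kwargs-style FusedMoEQuantConfig constructor (the hook name and attribute names are assumptions for illustration, not verbatim from this diff):

import torch

class Fp8MoEMethod(FusedMoEMethodBase):  # hypothetical subclass
    def get_fused_moe_quant_config(self, layer) -> FusedMoEQuantConfig:
        # The method subclass knows which scales its scheme needs, so it
        # assembles the config once instead of every caller passing
        # a1_scale/w1_scale/... through each call site.
        return FusedMoEQuantConfig(
            quant_dtype=torch.float8_e4m3fn,
            w1_scale=layer.w13_weight_scale,  # assumed attribute names
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )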
@@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
     def prepare(
         self,
         a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         num_experts: int,
@@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
         """
         Perform any quantization (and/or) dispatching needed for this kernel.
         - a1: The (unquantized) input to the MoE layer.
-        - a1_scale: Optional scales for a1
-        - a2_scale: Optional scales for the second MoE gemm. Required to make
-          sure the quantization is consistent for both gemms.
         - topk_ids: The topk ids.
         - topk_weights: The topk weights.
         - num_experts: The total number of experts in the global expert space.
@@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
           space to the local expert space of the expert parallel shard.
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.
+        - quant_config: Quantization info provided by the fused experts.

         Returns a tuple of:
         - quantized + dispatched a.
-        - quantized + dispatched a1_scales.
+        - Optional quantized + dispatched a1_scales.
         - Optional ExpertTokensMetadata containing gpu/cpu tensors
           as big as the number of local experts with the information about the
           number of tokens assigned to each local expert.
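With the config carrying the scales, a prepare() call site reduces to something like the sketch below (the argument order after num_experts is assumed from the docstring above; the unpacked names mirror the prepare() call that appears later in this diff):

a1q, a1q_scale, expert_tokens_meta, _topk_ids, _topk_weights = \
    prepare_finalize.prepare(
        a1,                            # unquantized MoE input
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,                  # scales now travel in here
    )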
@@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
     def prepare_async(
         self,
         a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         num_experts: int,
@@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError


+# TODO: add supported activations method (return string)
 class FusedMoEPermuteExpertsUnpermute(ABC):
     """
     An abstract base class for the [Permute-Experts-Unpermute] step described
@@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):

     def __init__(
         self,
-        quant_config: Optional[FusedMoEQuantConfig],
+        quant_config: FusedMoEQuantConfig,
     ):
-        if quant_config is not None:
-            self.quant_config = quant_config
-        else:
-            self.quant_config = FusedMoEQuantConfig()
+        """
+        quant_config: Quantization parameters for this experts instance.
+        """
+        self.quant_config = quant_config

     @property
     @abstractmethod
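Since quant_config is now a required constructor argument, a concrete experts class simply forwards one; a minimal sketch (the subclass name is hypothetical):

class ExampleExperts(FusedMoEPermuteExpertsUnpermute):  # hypothetical
    def __init__(self, quant_config: FusedMoEQuantConfig):
        # No None-handling needed any more: callers must supply a config.
        super().__init__(quant_config)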
@@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         """
         raise NotImplementedError

+    #
+    # Various helpers for accessing quantization parameters from the
+    # quant_config.
+    #
+
     @property
     def quant_dtype(self) -> Optional[torch.dtype]:
         return self.quant_config.quant_dtype
@@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def per_out_ch_quant(self) -> bool:
         return self.quant_config.per_out_ch_quant

+    @property
+    def a1_scale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.a1_scale
+
+    @property
+    def a2_scale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.a2_scale
+
+    @property
+    def a1_gscale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.a1_gscale
+
+    @property
+    def a2_gscale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.a2_gscale
+
+    @property
+    def w1_scale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w1_scale
+
+    @property
+    def w2_scale(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w2_scale
+
+    @property
+    def w1_zp(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w1_zp
+
+    @property
+    def w2_zp(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w2_zp
+
+    @property
+    def w1_bias(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w1_bias
+
+    @property
+    def w2_bias(self) -> Optional[torch.Tensor]:
+        return self.quant_config.w2_bias
+
+    @property
+    def g1_alphas(self) -> Optional[torch.Tensor]:
+        return self.quant_config.g1_alphas
+
+    @property
+    def g2_alphas(self) -> Optional[torch.Tensor]:
+        return self.quant_config.g2_alphas
+
     # TODO (bnell): make this return a CHUNK_SIZE or None instead?
     @abstractmethod
     def supports_chunking(self) -> bool:
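These accessors let an apply() implementation pull scales off the instance rather than its parameter list; a hedged sketch (body elided, parameter list following the reduced signature in the next hunk):

class ExampleExperts(FusedMoEPermuteExpertsUnpermute):  # hypothetical
    def apply(self, output, hidden_states, w1, w2, topk_weights,
              topk_ids, activation, global_num_experts, expert_map,
              a1q_scale, workspace13, workspace2, expert_tokens_meta,
              apply_router_weight_on_input):
        w1_scale = self.w1_scale  # from quant_config, was a parameter
        w2_scale = self.w2_scale  # from quant_config, was a parameter
        a2_scale = self.a2_scale  # from quant_config, was a parameter
        # a1q_scale stays a parameter: it is produced at runtime by
        # prepare(), not stored in the FusedMoEQuantConfig.
        ...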
@@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         activation: str,
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
@@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         - w1 (torch.Tensor): The first set of expert weights.
         - w2 (torch.Tensor): The second set of expert weights.
         - topk_weights: A map of row to expert weights. Some implementations
-        choose to do weight application.
+          choose to do weight application.
         - topk_ids (torch.Tensor): A map of row to expert id.
         - activation (str): The activation function to apply after the first
           MoE layer.
@@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
           from the global expert space to the local expert space of the expert
           parallel shard.
-        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
-        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
-        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
-          w1.
-        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
-          w2.
         - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
-          used for a1.
-        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
+          used for a1. Result of quantization from prepare/finalize and not
+          from the FusedMoEQuantConfig.
         - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
           must be large enough to hold output of either MoE gemm.
         - workspace2 (torch.Tensor): A scratch tensor used for the activation
@@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         global_num_experts: int,
         local_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
@@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            w1_zp=w1_zp,
-            w2_zp=w2_zp,
             a1q_scale=a1q_scale,
-            a2_scale=a2_scale,
             workspace13=workspace13,
             workspace2=workspace2,
             expert_tokens_meta=expert_tokens_meta,
@@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         global_num_experts: int,
         local_num_experts: int,
         expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ) -> torch.Tensor:
@@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             global_num_experts=global_num_experts,
             local_num_experts=local_num_experts,
             expert_map=expert_map,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            w1_zp=w1_zp,
-            w2_zp=w2_zp,
             a1q_scale=a1q_scale,
-            a2_scale=a2_scale,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
@@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
                 Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
             s = chunk_idx * CHUNK_SIZE
             e = min(s + CHUNK_SIZE, M)
-            return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
-                    _chunk_scales(a2_scale, s,
-                                  e), topk_ids[s:e], topk_weights[s:e])
+            return (
+                a1q[s:e],
+                _chunk_scales(a1q_scale, s, e),
+                _chunk_scales(self.fused_experts.a2_scale, s, e),
+                topk_ids[s:e],
+                topk_weights[s:e],
+            )

         def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
             assert fused_out.size(0) % M == 0, (
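The chunked path relies on _chunk_scales to slice per-token scales while leaving per-tensor scales alone. A plausible sketch of such a helper, assuming that contract (the real helper may differ):

def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    if scales is not None:
        if scales.numel() == 1:
            # Per-tensor scale: applies to every chunk unchanged.
            return scales
        # Per-token / grouped scales: take the rows for this chunk.
        return scales[start:end]
    return None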
@@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 global_num_experts=global_num_experts,
                 local_num_experts=local_num_experts,
                 expert_map=expert_map,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                w1_zp=w1_zp,
-                w2_zp=w2_zp,
                 a1q_scale=c_a1q_scale,
-                a2_scale=c_a2_scale,
                 expert_tokens_meta=c_expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
@@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         activation: str = "silu",
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        w1_zp: Optional[torch.Tensor] = None,
-        w2_zp: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """
@@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
           from the global expert space to the local expert space of the expert
           parallel shard.
-        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
-        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
-        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
-          w1.
-        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
-          w2.
-        - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
-        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
         - apply_router_weight_on_input (bool): When true, the topk weights are
           applied directly on the inputs. This is only applicable when topk is
           1.
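After the removals above, a top-level forward() call carries no quantization tensors at all; a hedged usage sketch (mk and the leading positional parameter names are assumptions, not shown in this diff):

out = mk.forward(                      # mk: a FusedMoEModularKernel
    hidden_states, w1, w2,             # assumed leading parameters
    topk_weights, topk_ids,
    activation="silu",
    global_num_experts=-1,
    expert_map=None,
    apply_router_weight_on_input=False,
)
# Scales/zero-points come from the quant_config held by mk.fused_experts.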
@@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
              a1,
-             a1_scale,
-             a2_scale,
              topk_weights,
              topk_ids,
              global_num_experts,
@@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
             dbo_maybe_run_recv_hook()
             hook, receiver = self.prepare_finalize.prepare_async(
                 a1,
-                a1_scale,
-                a2_scale,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
@@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             global_num_experts=global_num_experts,
             local_num_experts=local_num_experts,
             expert_map=expert_map,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            w1_zp=w1_zp,
-            w2_zp=w2_zp,
             a1q_scale=a1q_scale,
-            a2_scale=a2_scale,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )