[Bugfix] Fix chunked a2_scales in modular kernels (#25264)
Signed-off-by: Bill Nell <bnell@redhat.com>
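When the modular MoE path runs the expert computation in chunks, the activation scale used for the second quantization (a2_scale) has to correspond to the current chunk rather than the full tensor stored on the experts object. This change threads a2_scale through TritonExperts.apply as an explicit argument and uses it when quantizing intermediate_cache2, instead of always reading self.a2_scale.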
@@ -1598,6 +1598,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         a2q_scale: Optional[torch.Tensor] = None

         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
-            intermediate_cache2, self.a2_scale, self.quant_dtype,
+            intermediate_cache2, a2_scale, self.quant_dtype,
             self.per_act_token_quant, self.block_shape)

         invoke_fused_moe_kernel(
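For illustration, a minimal sketch (not the vLLM implementation) of the caller-side behavior this enables. run_experts and run_chunked are hypothetical stand-ins for the modular-kernel path; only the slicing of a per-token a2_scale in lockstep with its hidden-state chunk mirrors the fix above.

import torch

def run_experts(h_chunk: torch.Tensor, a2_scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real expert computation: it only checks that a
    # per-token scale lines up with the tokens in this chunk, then applies it.
    assert a2_scale.numel() in (1, h_chunk.shape[0])
    return h_chunk * a2_scale.view(-1, 1)

def run_chunked(hidden_states: torch.Tensor, a2_scale: torch.Tensor,
                chunk_size: int) -> torch.Tensor:
    outputs = []
    for start in range(0, hidden_states.shape[0], chunk_size):
        end = start + chunk_size
        h_chunk = hidden_states[start:end]
        # A per-act-token scale is indexed per token, so it must be sliced
        # with the chunk; a per-tensor scale is passed through unchanged.
        scale_chunk = a2_scale[start:end] if a2_scale.numel() > 1 else a2_scale
        outputs.append(run_experts(h_chunk, scale_chunk))
    return torch.cat(outputs, dim=0)

# Example: 10 tokens, hidden size 4, per-token scales, chunks of 4 tokens.
out = run_chunked(torch.randn(10, 4), torch.rand(10), chunk_size=4)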