[Bugfix] Fix chunked a2_scales in modular kernels (#25264)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-19 15:42:01 -04:00
committed by GitHub
parent 7852b82b93
commit 4bdf400218
11 changed files with 23 additions and 5 deletions

View File

@@ -1598,6 +1598,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     global_num_experts: int,
     expert_map: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
+    a2_scale: Optional[torch.Tensor],
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     a2q_scale: Optional[torch.Tensor] = None
     qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
-        intermediate_cache2, self.a2_scale, self.quant_dtype,
+        intermediate_cache2, a2_scale, self.quant_dtype,
         self.per_act_token_quant, self.block_shape)
     invoke_fused_moe_kernel(