[Bugfix] Fix chunked a2_scales in modular kernels (#25264)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-19 15:42:01 -04:00
committed by GitHub
parent 7852b82b93
commit 4bdf400218
11 changed files with 23 additions and 5 deletions

View File

@@ -241,6 +241,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
@@ -262,7 +263,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2,
self.c_strides1, self.c_strides2, workspace13, workspace2,
expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
@@ -705,6 +706,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], # unused
a2_scale: Optional[torch.Tensor], # unused
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],