[Bugfix] Fix chunked a2_scales in modular kernels (#25264)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-19 15:42:01 -04:00
committed by GitHub
parent 7852b82b93
commit 4bdf400218
11 changed files with 23 additions and 5 deletions

View File

@@ -519,6 +519,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
@@ -634,6 +635,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
@@ -671,6 +673,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
expert_map=expert_map,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
@@ -718,6 +721,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=a1q_scale,
a2_scale=self.fused_experts.a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -803,6 +807,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)