[Bugfix] Fix chunked a2_scales in modular kernels (#25264)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -519,6 +519,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
@@ -634,6 +635,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
@@ -671,6 +673,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
@@ -718,6 +721,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@@ -803,6 +807,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user