[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm
2025-08-15 14:46:00 -04:00
committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Any, Optional, final
from typing import Optional, final
import torch
@@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
@@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
@@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
self,
fused_out: Optional[torch.Tensor],
a1: torch.Tensor,
a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
return fused_out
@@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
# TODO(bnell): get rid of one level here, update slice functions
# to nops on num_chunks==1
if not self.fused_experts.supports_chunking() or num_chunks == 1:
return self._do_fused_experts(
fused_out=None,
@@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
# Chunking required case
assert num_chunks > 1
@@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
@@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
@@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
)
return fused_out
@@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
)
return output