[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
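The user-visible effect is that the three per-call side-channel dicts disappear from FusedMoEModularKernel.forward (see the final hunks below). A hedged before/after sketch of a call site, with the tensor argument list abbreviated and `mk` assumed to be a constructed modular kernel:

# Before: ad-hoc dicts threaded through prepare/apply/finalize on every call.
out = mk.forward(
    hidden_states, w1, w2, topk_weights, topk_ids,
    extra_expert_args={"m": hidden_states.shape[0]},
    extra_prepare_args=None,
    extra_finalize_args=None,
)

# After: the same call; per-kernel knobs live on the kernel objects instead.
out = mk.forward(hidden_states, w1, w2, topk_weights, topk_ids)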
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import Any, Optional, final
+from typing import Optional, final
 
 import torch
 
@@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):
 
     @abstractmethod
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
-        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
-               Optional[torch.Tensor]]:
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+    ) -> tuple[
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[ExpertTokensMetadata],
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+    ]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
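For implementers, the new prepare contract makes every input explicit (including quant_config) and fixes the five-slot return. A minimal pass-through sketch, standalone and outside the real base class; the slot meanings are inferred from this file (quantized activations, their scale, expert-token metadata, and optionally re-dispatched topk ids/weights):

import torch
from typing import Optional

def prepare_passthrough(
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config=None,  # FusedMoEQuantConfig in the real interface
):
    # No quantization, no cross-rank dispatch: hand the activations through.
    # None in the last three slots means "no metadata, topk unchanged".
    return a1, a1_scale, None, None, None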
@@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> None:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output.
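finalize likewise drops its dict; the reduction strategy now arrives only as the weight_and_reduce_impl object. A standalone sketch of the combine step it must perform, under assumed shape conventions (fused_expert_output is (num_tokens, topk, hidden)):

import torch

def finalize_sum(
    output: torch.Tensor,               # (num_tokens, hidden), written in place
    fused_expert_output: torch.Tensor,  # (num_tokens, topk, hidden)
    topk_weights: torch.Tensor,         # (num_tokens, topk)
    apply_router_weight_on_input: bool,
) -> None:
    if not apply_router_weight_on_input:
        # Weights were not folded into the inputs, so apply them here.
        fused_expert_output = fused_expert_output * topk_weights.unsqueeze(-1)
    # Reduce over the top-k dimension into the caller-provided buffer.
    output.copy_(fused_expert_output.sum(dim=1))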
@@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
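With extra_expert_args gone from apply, per-kernel knobs belong on the experts object at construction time rather than on each call. A sketch of the pattern (all names hypothetical):

import torch

class MyExperts:
    def __init__(self, block_size: int = 64):
        # Formerly smuggled through extra_expert_args on every call.
        self.block_size = block_size

    def apply(self, output: torch.Tensor, a1q: torch.Tensor) -> None:
        # Placeholder compute; the real method runs the expert GEMMs.
        output.copy_(a1q)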
@@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
                 f"{fused_experts.activation_formats[0]}")
 
     def _do_fused_experts(
-        self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
-        a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
-        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-        activation: str, global_num_experts: int, local_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
-        apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
+        self,
+        fused_out: Optional[torch.Tensor],
+        a1: torch.Tensor,
+        a1q: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ) -> torch.Tensor:
 
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
 
@@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             workspace2=workspace2,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            extra_expert_args=extra_expert_args)
+        )
 
         return fused_out
 
@@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a2_scale: Optional[torch.Tensor],
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ) -> torch.Tensor:
 
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)
 
+        # TODO(bnell): get rid of one level here, update slice functions
+        # to nops on num_chunks==1
+
         if not self.fused_experts.supports_chunking() or num_chunks == 1:
             return self._do_fused_experts(
                 fused_out=None,
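The chunked path splits the M input rows into cdiv(M, CHUNK_SIZE) pieces and falls through to a single _do_fused_experts call when chunking is unsupported or unnecessary. The arithmetic, standalone:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vLLM's cdiv helper.
    return (a + b - 1) // b

M, CHUNK_SIZE = 10_000, 4_096
num_chunks = cdiv(M, CHUNK_SIZE)   # 3
for chunk_idx in range(num_chunks):
    s = chunk_idx * CHUNK_SIZE
    e = min(s + CHUNK_SIZE, M)     # final chunk is ragged: rows 8192..9999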
@@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 expert_tokens_meta=expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args)
+            )
 
         # Chunking required case
         assert num_chunks > 1
@@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                     expert_num_tokens=c_expert_num_tokens,
                     expert_num_tokens_cpu=c_expert_num_tokens_cpu)
 
-        m = None
-        if extra_expert_args is not None and 'm' in extra_expert_args:
-            m = extra_expert_args.get('m')
-
-        if extra_expert_args is not None:
-            chunked_extra_expert_args = extra_expert_args
-        else:
-            chunked_extra_expert_args = {}
-
         for chunk_idx in range(num_chunks):
             c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                 slice_input_tensors(chunk_idx))
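The deleted block above threaded the per-chunk row count through extra_expert_args['m']. That count is recoverable from the sliced tensors themselves, which is presumably what makes the side channel unnecessary; a minimal illustration:

import torch

def per_chunk_m(a1q_chunk: torch.Tensor) -> int:
    # The value previously passed as extra_expert_args['m'] is just the
    # leading dimension of the sliced activation chunk.
    return a1q_chunk.shape[0]

assert per_chunk_m(torch.empty(1808, 4096)) == 1808  # the ragged chunk above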
@@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                     expert_tokens_meta, c_topk_ids, local_num_experts,
                     expert_map)
 
-            s = chunk_idx * CHUNK_SIZE
-            e = min(s + CHUNK_SIZE, M)
-
-            if m is not None:
-                chunked_extra_expert_args['m'] = e - s
             self._do_fused_experts(
                 fused_out=slice_output_tensor(chunk_idx),
                 a1=a1,
@@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=c_a2_scale,
                 expert_tokens_meta=c_expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=chunked_extra_expert_args)
+            )
 
         return fused_out
 
@@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
-        extra_expert_args: Optional[dict] = None,
-        extra_prepare_args: Optional[dict] = None,
-        extra_finalize_args: Optional[dict] = None,
     ) -> torch.Tensor:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
@@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         - apply_router_weight_on_input (bool): When true, the topk weights are
           applied directly on the inputs. This is only applicable when topk is
          1.
-        - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
-          fused_experts.apply.
-        - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
-          to prepare.
-        - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
-          to finalize.
 
         Returns:
         - torch.Tensor: The output tensor after applying the MoE layer.
@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
-            extra_prepare_args,
         )
 
         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 expert_tokens_meta=expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args)
+            )
 
             self.prepare_finalize.finalize(
-                output, fused_out, topk_weights, topk_ids,
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
-                extra_finalize_args)
+            )
 
             return output
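Wiring is unchanged by this commit: a FusedMoEModularKernel still composes one FusedMoEPrepareAndFinalize with one FusedMoEPermuteExpertsUnpermute, as the diff's error message and self.prepare_finalize/self.fused_experts attributes show; only the method signatures got leaner. A hedged sketch of the composition (constructor keyword names and the forward argument list are assumptions, and the implementation classes are the hypothetical ones sketched above):

mk = FusedMoEModularKernel(
    prepare_finalize=MyPrepareAndFinalize(),  # hypothetical subclass
    fused_experts=MyExperts(),                # hypothetical subclass
)
out = mk.forward(hidden_states, w1, w2, topk_weights, topk_ids)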