[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm
2025-08-15 14:46:00 -04:00
committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -7,41 +7,22 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
make_quant_fp8_weights, per_token_cast_to_fp8)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
@@ -69,24 +50,31 @@ class Config:
torch_trace_dir_path: Optional[str] = None
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
def describe(self) -> str:
s = ""
s += "== Config: \n"
s += f" world_size={self.world_size} \n"
s += f" PF={self.prepare_finalize_type.__name__} \n"
s += f" FE={self.fused_experts_type.__name__} \n"
s += f" topk={self.topks} \n"
s += f" dtype={self.dtype} \n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
s += " Quant: \n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n "
s += "== Config:\n"
s += f" world_size={self.world_size}\n"
s += f" PF={self.prepare_finalize_type.__name__}\n"
s += f" FE={self.fused_experts_type.__name__}\n"
s += f" E={self.E}\n"
s += f" Ms={self.Ms}\n"
s += f" N={self.N}\n"
s += f" K={self.K}\n"
s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype} \n"
s += f" q_block_shape={self.quant_block_shape} \n"
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
s += f" q_per_act_token={self.is_per_act_token_quant} \n"
s += f" q_dtype={self.quant_dtype}\n"
s += f" q_block_shape={self.quant_block_shape}\n"
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
s += f" q_per_act_token={self.is_per_act_token_quant}\n"
else:
s += " quant=None \n"
s += " quant=None\n"
return s
@property
@@ -95,34 +83,28 @@ class Config:
return self.Ms
@property
def quant_dtype(self) -> Optional[torch.dtype]:
if self.quant_config is None:
return None
def quant_dtype(self) -> Union[torch.dtype, str, None]:
assert self.quant_config is not None
return self.quant_config.quant_dtype
@property
def is_per_act_token_quant(self) -> bool:
if self.quant_config is None:
return False
assert self.quant_config is not None
return self.quant_config.per_act_token_quant
@property
def is_per_tensor_act_quant(self) -> bool:
if self.quant_config is None:
return False
return (not self.is_per_act_token_quant
and self.quant_block_shape is None)
@property
def is_per_out_ch_quant(self) -> bool:
if self.quant_config is None:
return False
assert self.quant_config is not None
return self.quant_config.per_out_ch_quant
@property
def quant_block_shape(self) -> Optional[list[int]]:
if self.quant_config is None:
return None
assert self.quant_config is not None
return self.quant_config.block_shape
@property
@@ -130,36 +112,30 @@ class Config:
assert isinstance(self.topks, int)
return self.topks
@property
def topk_ids_dtype(self) -> Optional[torch.dtype]:
topk_ids_dtype = None
if self.prepare_finalize_type == PplxPrepareAndFinalize:
topk_ids_dtype = torch.uint32
elif self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]:
topk_ids_dtype = torch.int64
return topk_ids_dtype
@property
def num_local_experts(self) -> int:
return self.E // self.world_size
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
"""
make env data for vllm launch.
make env data for vllm launch.
"""
vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = self.world_size
vllm_config.parallel_config.enable_expert_parallel = True
env_dict = {
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
}
backend = self.all2all_backend()
if backend is not None:
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
return vllm_config, env_dict
def is_fp8_block_quantized(self):
@@ -167,85 +143,59 @@ class Config:
and self.quant_block_shape is not None)
def is_batched_prepare_finalize(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
def is_batched_fused_experts(self):
return self.fused_experts_type in [
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
]
info = expert_info(self.fused_experts_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
def is_standard_fused_experts(self):
return self.fused_experts_type in [
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
TritonExperts
]
info = expert_info(self.fused_experts_type)
return mk.FusedMoEActivationFormat.Standard == info.activation_format
def is_fe_16bit_supported(self):
return self.fused_experts_type in [
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
NaiveBatchedExperts, TritonExperts
]
def fe_supported_types(self):
info = expert_info(self.fused_experts_type)
return info.supported_dtypes
def is_fe_fp8_supported(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
BatchedTritonExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
NaiveBatchedExperts,
]
def pf_supported_types(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supported_dtypes
def is_fe_block_fp8_supported(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
BatchedTritonOrDeepGemmExperts,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
]
def is_block_quant_supported(self):
info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support
def is_fe_supports_chunking(self):
return self.fused_experts_type in [
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
TritonExperts
]
info = expert_info(self.fused_experts_type)
return info.supports_chunking
def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
def supports_apply_weight_on_input(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supports_apply_weight_on_input
def needs_deep_gemm(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
DeepGemmExperts,
]
info = expert_info(self.fused_experts_type)
return info.needs_deep_gemm
def needs_pplx(self):
return self.prepare_finalize_type in [PplxPrepareAndFinalize]
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend == "pplx"
def needs_deep_ep(self):
return self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return (info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency")
def all2all_backend(self):
if self.needs_pplx():
return "pplx"
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
return "deepep_high_throughput"
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
return "deepep_low_latency"
return "naive"
def needs_all2all(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend
def is_valid(self):
# Check prepare-finalize and fused-experts compatibility
@@ -267,28 +217,28 @@ class Config:
# invalid quant config
return False
# check bf16 / fp16 support
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
if is_16bit and not self.is_fe_16bit_supported():
return False
# check type support
if self.quant_dtype is None:
if (self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()):
return False
else:
if (self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()):
return False
# Check fp8 support
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
if is_fp8 and not self.is_fe_fp8_supported():
return False
# Check fp8 block quanization support
# Check block quanization support
is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and not is_fp8:
if is_block_quatized and self.quant_dtype is None:
return False
if is_block_quatized and not self.is_fe_block_fp8_supported():
if is_block_quatized and not self.is_block_quant_supported():
return False
# deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized:
return False
# Check dependencies
# Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep():
return False
if self.needs_deep_gemm() and not has_deep_gemm():
@@ -305,6 +255,8 @@ class WeightTensors:
w2: torch.Tensor
w1_scale: Optional[torch.Tensor]
w2_scale: Optional[torch.Tensor]
w1_gs: Optional[torch.Tensor] = None
w2_gs: Optional[torch.Tensor] = None
def describe(self):
s = ""
@@ -313,13 +265,20 @@ class WeightTensors:
s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device())
is_quantized = self.w1.dtype == torch.float8_e4m3fn
if is_quantized:
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to(
@@ -327,56 +286,51 @@ class WeightTensors:
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
is_quantized = self.w1.dtype == torch.float8_e4m3fn
w1_scale, w2_scale = (None, None)
if is_quantized:
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :]
return WeightTensors(w1, w2, w1_scale, w2_scale)
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@staticmethod
def make(config: Config) -> "WeightTensors":
if config.quant_dtype is None:
# just make normal dtype weights
w1, w2 = make_non_quant_weights(e=config.E,
n=config.N,
k=config.K,
dtype=config.dtype)
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
assert config.quant_dtype == torch.float8_e4m3fn
if not config.is_fp8_block_quantized():
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
e=config.E,
n=config.N,
k=config.K,
per_out_channel_quant=config.is_per_out_ch_quant,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
assert config.quant_block_shape is not None
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
(_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
e=config.E,
n=config.N,
k=config.K,
block_size=config.quant_block_shape,
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass
@@ -449,7 +403,6 @@ class RankTensors:
dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False)
topk_ids = topk_ids.to(config.topk_ids_dtype)
# distribute topk_ids evenly
for mi in range(m):
@@ -457,7 +410,7 @@ class RankTensors:
topk_ids = topk_ids.to(device=torch.cuda.current_device())
expert_map = None
if config.world_size > 1:
if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ),
fill_value=-1,
dtype=torch.int32)
@@ -480,92 +433,100 @@ class RankTensors:
def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor:
return torch_experts(a=rank_tensors.hidden_states,
w1=weights.w1,
w2=weights.w2,
if config.quant_dtype == "nvfp4":
quant_blocksize = 16
dtype = config.dtype
w1_q = weights.w1
w1_blockscale = weights.w1_scale
w1_gs = weights.w1_gs
w2_q = weights.w2
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
assert w1_blockscale.shape[1] % 128 == 0
assert w1_blockscale.shape[2] % 4 == 0
assert w2_blockscale.shape[1] % 128 == 0
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
k = w2_q.shape[1]
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
a_scale = None
w1_scale = None
w2_scale = None
quant_dtype = None
per_act_token_quant = False
block_shape = None
else:
a = rank_tensors.hidden_states
a_scale = rank_tensors.hidden_states_scale
w1 = weights.w1
w1_scale = weights.w1_scale
w2 = weights.w2
w2_scale = weights.w2_scale
quant_dtype = config.quant_dtype
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=weights.w1_scale,
w2_scale=weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
quant_dtype=config.quant_dtype,
per_act_token_quant=config.is_per_act_token_quant,
block_shape=config.quant_block_shape,
apply_router_weights_on_input=config.topk == 1)
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
def make_fused_experts(
config: Config, moe: FusedMoEConfig,
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if config.fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif config.fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif config.fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif config.fused_experts_type == CutlassExpertsFp8:
use_batched_format = config.is_batched_prepare_finalize()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
kwargs = {
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": config.is_per_act_token_quant,
"per_out_ch_quant": config.is_per_out_ch_quant,
"block_shape": config.quant_block_shape,
"num_dispatchers": num_dispatchers,
"use_batched_format": use_batched_format
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
return experts
def make_modular_kernel(config: Config,
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
import math
@@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config,
)
moe = FusedMoEConfig(
num_experts=config.E,
experts_per_token=config.topk,
@@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
)
# make modular kernel
prepare_finalize = None
if config.needs_all2all():
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
else:
prepare_finalize = MoEPrepareAndFinalizeNoEP()
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe)
fused_experts = make_fused_experts(config, moe,
prepare_finalize.num_dispatchers())
fused_experts = make_fused_experts(
config.fused_experts_type,
moe,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
)
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
@@ -620,22 +583,45 @@ def run_modular_kernel(
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config)
mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = {
"hidden_states": rank_tensors.hidden_states.clone(
"hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place
"w1": rank_weights.w1,
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": rank_tensors.topk_ids,
"expert_map": rank_tensors.expert_map,
"w1_scale": rank_weights.w1_scale,
"w2_scale": rank_weights.w2_scale,
"a1_scale": rank_tensors.hidden_states_scale,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1,
"w1":
rank_weights.w1,
"w2":
rank_weights.w2,
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
}
out = mk.forward(**mk_kwargs)
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)
return out