[Feat][Perf] Enable deepep-low-latency with round-robin expert placement. (#28449)
Signed-off-by: bruceszchen <bruceszchen@tencent.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
|
|||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
quant_config: FusedMoEQuantConfig | None,
|
quant_config: FusedMoEQuantConfig | None,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
) -> FusedMoEPrepareAndFinalize | None:
|
) -> FusedMoEPrepareAndFinalize | None:
|
||||||
if not moe.moe_parallel_config.use_all2all_kernels:
|
if not moe.moe_parallel_config.use_all2all_kernels:
|
||||||
return None
|
return None
|
||||||
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
|
|||||||
|
|
||||||
elif moe.use_deepep_ll_kernels:
|
elif moe.use_deepep_ll_kernels:
|
||||||
assert quant_config is not None
|
assert quant_config is not None
|
||||||
|
global_to_physical = physical_to_global = local_expert_global_ids = None
|
||||||
|
if routing_tables is not None:
|
||||||
|
(
|
||||||
|
global_to_physical,
|
||||||
|
physical_to_global,
|
||||||
|
local_expert_global_ids,
|
||||||
|
) = routing_tables
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||||
token_hidden_size=moe.hidden_dim,
|
token_hidden_size=moe.hidden_dim,
|
||||||
@@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
|
|||||||
max_tokens_per_rank=moe.max_num_tokens,
|
max_tokens_per_rank=moe.max_num_tokens,
|
||||||
num_dispatchers=all2all_manager.world_size,
|
num_dispatchers=all2all_manager.world_size,
|
||||||
use_fp8_dispatch=use_fp8_dispatch,
|
use_fp8_dispatch=use_fp8_dispatch,
|
||||||
|
global_to_physical=global_to_physical,
|
||||||
|
physical_to_global=physical_to_global,
|
||||||
|
local_expert_global_ids=local_expert_global_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
|
|||||||
@@ -85,6 +85,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
max_tokens_per_rank: int,
|
max_tokens_per_rank: int,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
use_fp8_dispatch: bool = False,
|
use_fp8_dispatch: bool = False,
|
||||||
|
global_to_physical: torch.Tensor | None = None,
|
||||||
|
physical_to_global: torch.Tensor | None = None,
|
||||||
|
local_expert_global_ids: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -97,6 +100,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
self.handles: list[tuple | None] = [None, None]
|
self.handles: list[tuple | None] = [None, None]
|
||||||
self.num_dispatchers_ = num_dispatchers
|
self.num_dispatchers_ = num_dispatchers
|
||||||
|
|
||||||
|
topk_indices_dtype = self.topk_indices_dtype()
|
||||||
|
|
||||||
|
def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||||
|
if tensor is None or topk_indices_dtype is None:
|
||||||
|
return tensor
|
||||||
|
return tensor.to(dtype=topk_indices_dtype)
|
||||||
|
|
||||||
|
self.global_to_physical = _maybe_cast(global_to_physical)
|
||||||
|
self.physical_to_global = _maybe_cast(physical_to_global)
|
||||||
|
self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
|
||||||
|
|
||||||
# We don't have enough information to determine if we should dispatch
|
# We don't have enough information to determine if we should dispatch
|
||||||
# activation scales in a packed ue8m0 format during object construction
|
# activation scales in a packed ue8m0 format during object construction
|
||||||
# time. This setting is handled by post_init_setup.
|
# time. This setting is handled by post_init_setup.
|
||||||
@@ -136,6 +150,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||||
return torch.int64
|
return torch.int64
|
||||||
|
|
||||||
|
def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
    """Translate expert ids through the ``global_to_physical`` table.

    Args:
        topk_ids: Tensor of expert ids used to index the lookup table.

    Returns:
        ``topk_ids`` itself when no ``global_to_physical`` table is set;
        otherwise the table entries selected by ``topk_ids``.
    """
    lookup = self.global_to_physical
    return topk_ids if lookup is None else lookup[topk_ids]
|
||||||
|
|
||||||
|
def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
    """Translate expert ids through the ``local_expert_global_ids`` table.

    Args:
        expert_topk_ids: Tensor of expert ids used to index the lookup table.

    Returns:
        ``expert_topk_ids`` itself when no ``local_expert_global_ids`` table
        is set; otherwise the table entries selected by ``expert_topk_ids``.
    """
    lookup = self.local_expert_global_ids
    return expert_topk_ids if lookup is None else lookup[expert_topk_ids]
|
||||||
|
|
||||||
def _do_quant(
|
def _do_quant(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||||
@@ -226,9 +250,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1 = a1 * topk_weights.to(a1.dtype)
|
a1 = a1 * topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
|
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
||||||
a1,
|
a1,
|
||||||
topk_ids,
|
dispatch_topk_ids,
|
||||||
self.max_tokens_per_rank,
|
self.max_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
use_fp8=self.use_fp8_dispatch,
|
use_fp8=self.use_fp8_dispatch,
|
||||||
@@ -313,11 +338,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# weights have already been applied.
|
# weights have already been applied.
|
||||||
combine_topk_weights = torch.ones_like(topk_weights)
|
combine_topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
|
combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||||
# TODO (varun) : Enable zero copy mode
|
# TODO (varun) : Enable zero copy mode
|
||||||
dbo_maybe_run_recv_hook()
|
dbo_maybe_run_recv_hook()
|
||||||
_, _, recv_hook = self.buffer.low_latency_combine(
|
_, _, recv_hook = self.buffer.low_latency_combine(
|
||||||
fused_expert_output,
|
fused_expert_output,
|
||||||
topk_ids,
|
combine_topk_ids,
|
||||||
combine_topk_weights,
|
combine_topk_weights,
|
||||||
handle,
|
handle,
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
|
|||||||
@@ -50,10 +50,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> FusedMoEPrepareAndFinalize | None:
|
||||||
from .all2all_utils import maybe_make_prepare_finalize
|
from .all2all_utils import maybe_make_prepare_finalize
|
||||||
|
|
||||||
return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
|
return maybe_make_prepare_finalize(
|
||||||
|
self.moe, self.moe_quant_config, routing_tables
|
||||||
|
)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Literal, get_args, overload
|
from typing import Literal, cast, get_args, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -192,6 +192,42 @@ def determine_expert_map(
|
|||||||
return (local_num_experts, expert_map, expert_mask)
|
return (local_num_experts, expert_map, expert_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def determine_expert_placement_strategy(
    expert_placement_strategy: ExpertPlacementStrategy,
    moe_parallel_config: FusedMoEParallelConfig,
    num_expert_group: int | None,
    num_redundant_experts: int,
    enable_eplb: bool,
) -> ExpertPlacementStrategy:
    """Resolve the expert placement strategy that can actually be used.

    Any strategy other than ``"round_robin"`` is returned unchanged.
    ``"round_robin"`` is kept only when every condition below holds;
    otherwise a warning is logged and ``"linear"`` is returned.

    * ``num_expert_group`` is set and greater than 1,
    * ``num_redundant_experts`` is 0,
    * ``enable_eplb`` is false,
    * all2all kernels are either not in use, or the DeepEP low-latency
      kernels are the ones in use.

    Args:
        expert_placement_strategy: The requested strategy.
        moe_parallel_config: Read for ``use_all2all_kernels``,
            ``use_deepep_ll_kernels`` and (for logging) ``all2all_backend``.
        num_expert_group: Number of expert groups, or ``None``.
        num_redundant_experts: Number of redundant experts.
        enable_eplb: Whether EPLB is enabled.

    Returns:
        The requested strategy, or ``"linear"`` when round-robin placement
        was requested but cannot be honoured.
    """
    if expert_placement_strategy != "round_robin":
        return expert_placement_strategy

    round_robin_supported = (
        num_expert_group is not None
        and num_expert_group > 1
        and num_redundant_experts == 0
        and not enable_eplb
    )
    if not round_robin_supported:
        # The message lists every condition checked above, including EPLB,
        # so the reported reason matches what actually caused the fallback.
        logger.warning(
            "Round-robin expert placement is only supported for "
            "models with multiple expert groups, no redundant "
            "experts and EPLB disabled. Falling back to linear "
            "expert placement."
        )
        return "linear"

    if (
        moe_parallel_config.use_all2all_kernels
        and not moe_parallel_config.use_deepep_ll_kernels
    ):
        logger.warning(
            "Round-robin expert placement currently only supports "
            "the DeepEP low-latency backend, but '%s' was configured. "
            "Falling back to linear expert placement.",
            moe_parallel_config.all2all_backend,
        )
        return "linear"

    return expert_placement_strategy
|
||||||
|
|
||||||
|
|
||||||
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||||
"""
|
"""
|
||||||
Compresses the expert map by removing any -1 entries.
|
Compresses the expert map by removing any -1 entries.
|
||||||
@@ -400,6 +436,9 @@ class FusedMoE(CustomOp):
|
|||||||
self.expert_load_view: torch.Tensor | None = None
|
self.expert_load_view: torch.Tensor | None = None
|
||||||
self.logical_to_physical_map: torch.Tensor | None = None
|
self.logical_to_physical_map: torch.Tensor | None = None
|
||||||
self.logical_replica_count: torch.Tensor | None = None
|
self.logical_replica_count: torch.Tensor | None = None
|
||||||
|
self.expert_placement_strategy: ExpertPlacementStrategy = (
|
||||||
|
vllm_config.parallel_config.expert_placement_strategy
|
||||||
|
)
|
||||||
|
|
||||||
# ROCm aiter shared experts fusion
|
# ROCm aiter shared experts fusion
|
||||||
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||||
@@ -433,38 +472,27 @@ class FusedMoE(CustomOp):
|
|||||||
"Redundant experts are only supported with EPLB."
|
"Redundant experts are only supported with EPLB."
|
||||||
)
|
)
|
||||||
|
|
||||||
expert_placement_strategy = (
|
self.expert_placement_strategy = determine_expert_placement_strategy(
|
||||||
vllm_config.parallel_config.expert_placement_strategy
|
expert_placement_strategy=self.expert_placement_strategy,
|
||||||
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_redundant_experts=num_redundant_experts,
|
||||||
|
enable_eplb=self.enable_eplb,
|
||||||
)
|
)
|
||||||
if expert_placement_strategy == "round_robin":
|
|
||||||
# TODO(Bruce): will support round robin expert placement with
|
|
||||||
# EPLB enabled in the future.
|
|
||||||
round_robin_supported = (
|
|
||||||
(num_expert_group is not None and num_expert_group > 1)
|
|
||||||
and num_redundant_experts == 0
|
|
||||||
and not self.enable_eplb
|
|
||||||
)
|
|
||||||
|
|
||||||
if not round_robin_supported:
|
|
||||||
logger.warning(
|
|
||||||
"Round-robin expert placement is only supported for "
|
|
||||||
"models with multiple expert groups and no redundant "
|
|
||||||
"experts. Falling back to linear expert placement."
|
|
||||||
)
|
|
||||||
expert_placement_strategy = "linear"
|
|
||||||
|
|
||||||
self.expert_map: torch.Tensor | None
|
self.expert_map: torch.Tensor | None
|
||||||
local_num_experts, expert_map, expert_mask = determine_expert_map(
|
local_num_experts, expert_map, expert_mask = determine_expert_map(
|
||||||
ep_size=self.ep_size,
|
ep_size=self.ep_size,
|
||||||
ep_rank=self.ep_rank,
|
ep_rank=self.ep_rank,
|
||||||
global_num_experts=self.global_num_experts,
|
global_num_experts=self.global_num_experts,
|
||||||
expert_placement_strategy=expert_placement_strategy,
|
expert_placement_strategy=self.expert_placement_strategy,
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||||
)
|
)
|
||||||
self.local_num_experts = local_num_experts
|
self.local_num_experts = local_num_experts
|
||||||
self.register_buffer("expert_map", expert_map)
|
self.register_buffer("expert_map", expert_map)
|
||||||
self.register_buffer("expert_mask", expert_mask)
|
self.register_buffer("expert_mask", expert_mask)
|
||||||
|
self._maybe_init_expert_routing_tables()
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
|
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
|
||||||
"placement strategy: %s. Local/global"
|
"placement strategy: %s. Local/global"
|
||||||
@@ -472,7 +500,7 @@ class FusedMoE(CustomOp):
|
|||||||
" %s.",
|
" %s.",
|
||||||
self.ep_rank,
|
self.ep_rank,
|
||||||
self.ep_size,
|
self.ep_size,
|
||||||
expert_placement_strategy,
|
self.expert_placement_strategy,
|
||||||
self.local_num_experts,
|
self.local_num_experts,
|
||||||
self.global_num_experts,
|
self.global_num_experts,
|
||||||
get_compressed_expert_map(self.expert_map),
|
get_compressed_expert_map(self.expert_map),
|
||||||
@@ -621,7 +649,12 @@ class FusedMoE(CustomOp):
|
|||||||
# should be safe to swap out the quant_method.
|
# should be safe to swap out the quant_method.
|
||||||
def maybe_init_modular_kernel(self) -> None:
|
def maybe_init_modular_kernel(self) -> None:
|
||||||
self.ensure_moe_quant_config_init()
|
self.ensure_moe_quant_config_init()
|
||||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
|
# routing_tables only needed for round-robin expert placement with
|
||||||
|
# DeepEP all2all backend.
|
||||||
|
routing_tables = self._maybe_init_expert_routing_tables()
|
||||||
|
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
|
||||||
|
routing_tables=routing_tables
|
||||||
|
)
|
||||||
if prepare_finalize is not None:
|
if prepare_finalize is not None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
|
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
|
||||||
@@ -703,6 +736,84 @@ class FusedMoE(CustomOp):
|
|||||||
# By default, router/gate is called before FusedMoE forward pass
|
# By default, router/gate is called before FusedMoE forward pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _maybe_init_expert_routing_tables(
    self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
    """Build (or reuse) the round-robin expert routing tables.

    The tables are only produced when the placement strategy is
    ``"round_robin"`` and the DeepEP low-latency kernels are in use;
    otherwise ``None`` is returned.

    The tables are registered as the buffers ``expert_global_to_physical``,
    ``expert_physical_to_global`` and ``expert_local_to_global``. They are
    reused on later calls only while the inputs they were derived from
    (``global_num_experts``, ``ep_size``, ``ep_rank``,
    ``local_num_experts``) are unchanged. If any of those changed, the
    tables are rebuilt instead of returning stale ones.

    Returns:
        ``(global_to_physical, physical_to_global, local_to_global)`` or
        ``None`` when routing tables are not needed or ``expert_map`` is
        ``None``.
    """
    # Currently routing_tables only needed for round-robin expert placement
    # with DeepEP-ll all2all backend.
    if (
        self.expert_placement_strategy != "round_robin"
        or not self.use_deepep_ll_kernels
    ):
        return None

    # Everything the tables are computed from. Checking only for the
    # buffer's existence would keep returning the first tables ever built,
    # even after the expert-parallel layout has been updated.
    layout_key = (
        self.global_num_experts,
        self.ep_size,
        self.ep_rank,
        self.local_num_experts,
    )
    if (
        hasattr(self, "expert_global_to_physical")
        and getattr(self, "_expert_routing_tables_key", None) == layout_key
    ):
        return cast(
            tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            (
                self.expert_global_to_physical,
                self.expert_physical_to_global,
                self.expert_local_to_global,
            ),
        )

    if self.expert_map is None:
        return None

    routing_tables = self.ensure_round_robin_expert_routing_tables(
        global_num_experts=self.global_num_experts,
        ep_size=self.ep_size,
        ep_rank=self.ep_rank,
        local_num_experts=self.local_num_experts,
        device=self.expert_map.device,
    )

    global_to_physical, physical_to_global, local_global = routing_tables
    # Registering under a name that is already a buffer replaces it; this is
    # the same pattern used elsewhere in this class for "expert_map".
    self.register_buffer("expert_global_to_physical", global_to_physical)
    self.register_buffer("expert_physical_to_global", physical_to_global)
    self.register_buffer("expert_local_to_global", local_global)
    self._expert_routing_tables_key = layout_key

    return routing_tables
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ensure_round_robin_expert_routing_tables(
|
||||||
|
global_num_experts: int,
|
||||||
|
ep_size: int,
|
||||||
|
ep_rank: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
device: torch.device | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
device_kwargs = {"device": device} if device is not None else {}
|
||||||
|
global_indices = torch.arange(
|
||||||
|
global_num_experts, dtype=torch.long, **device_kwargs
|
||||||
|
)
|
||||||
|
owner = torch.remainder(global_indices, ep_size)
|
||||||
|
local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
|
||||||
|
base = global_num_experts // ep_size
|
||||||
|
remainder = global_num_experts % ep_size
|
||||||
|
physical_offset = owner * base
|
||||||
|
if remainder > 0:
|
||||||
|
remainder_tensor = torch.tensor(
|
||||||
|
remainder, dtype=torch.long, **device_kwargs
|
||||||
|
)
|
||||||
|
physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
|
||||||
|
|
||||||
|
global_to_physical = physical_offset + local_index
|
||||||
|
physical_to_global = torch.empty_like(global_to_physical)
|
||||||
|
physical_to_global[global_to_physical] = global_indices
|
||||||
|
|
||||||
|
local_global = torch.arange(
|
||||||
|
ep_rank,
|
||||||
|
global_num_experts,
|
||||||
|
ep_size,
|
||||||
|
dtype=torch.long,
|
||||||
|
**device_kwargs,
|
||||||
|
)
|
||||||
|
if local_global.numel() != local_num_experts:
|
||||||
|
local_global = local_global[:local_num_experts]
|
||||||
|
|
||||||
|
return (global_to_physical, physical_to_global, local_global)
|
||||||
|
|
||||||
def update_expert_map(self):
|
def update_expert_map(self):
|
||||||
# ep_size and ep_rank should already be updated
|
# ep_size and ep_rank should already be updated
|
||||||
assert self.expert_map is not None
|
assert self.expert_map is not None
|
||||||
@@ -711,12 +822,14 @@ class FusedMoE(CustomOp):
|
|||||||
ep_size=self.ep_size,
|
ep_size=self.ep_size,
|
||||||
ep_rank=self.ep_rank,
|
ep_rank=self.ep_rank,
|
||||||
global_num_experts=self.global_num_experts,
|
global_num_experts=self.global_num_experts,
|
||||||
|
expert_placement_strategy=self.expert_placement_strategy,
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||||
)
|
)
|
||||||
self.local_num_experts = local_num_experts
|
self.local_num_experts = local_num_experts
|
||||||
self.register_buffer("expert_map", expert_map)
|
self.register_buffer("expert_map", expert_map)
|
||||||
self.register_buffer("expert_mask", expert_mask)
|
self.register_buffer("expert_mask", expert_mask)
|
||||||
|
self._maybe_init_expert_routing_tables()
|
||||||
if self.aiter_fmoe_shared_expert_enabled:
|
if self.aiter_fmoe_shared_expert_enabled:
|
||||||
self._init_aiter_shared_experts_topK_buffer(
|
self._init_aiter_shared_experts_topK_buffer(
|
||||||
vllm_config=get_current_vllm_config(),
|
vllm_config=get_current_vllm_config(),
|
||||||
|
|||||||
@@ -108,11 +108,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
def allow_inplace(self) -> bool:
|
def allow_inplace(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> FusedMoEPrepareAndFinalize | None:
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -380,11 +380,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
(layer.w2_input_global_scale), requires_grad=False
|
(layer.w2_input_global_scale), requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
return None
|
return None
|
||||||
elif not self.allow_flashinfer:
|
elif not self.allow_flashinfer:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
|
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
@@ -890,11 +893,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w2_weight_scale
|
layer.w2_weight_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1018,7 +1018,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
if (
|
if (
|
||||||
self.rocm_aiter_moe_enabled
|
self.rocm_aiter_moe_enabled
|
||||||
or self.use_marlin
|
or self.use_marlin
|
||||||
@@ -1039,7 +1042,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
else:
|
else:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -373,6 +373,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
# TRT LLM not supported with all2all yet.
|
# TRT LLM not supported with all2all yet.
|
||||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||||
@@ -384,7 +385,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
else:
|
else:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
@@ -1179,7 +1180,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
" for ModelOptNvFp4FusedMoE."
|
" for ModelOptNvFp4FusedMoE."
|
||||||
)
|
)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
def maybe_make_prepare_finalize(
|
||||||
|
self,
|
||||||
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
|
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||||
if self.use_marlin or (
|
if self.use_marlin or (
|
||||||
self.allow_flashinfer
|
self.allow_flashinfer
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
@@ -1196,7 +1200,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
else:
|
else:
|
||||||
return super().maybe_make_prepare_finalize()
|
return super().maybe_make_prepare_finalize(routing_tables)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user