[AMD][ROCm] MoRI EP: a high-performance all2all backend (#28664)
Signed-off-by: Alex Sun <alex.s@amd.com>
@@ -141,7 +141,7 @@ def make_config(args: argparse.Namespace) -> Config:
    quant_config = None
    if args.quant_dtype is not None:
        quant_config = FusedMoEQuantConfig(
        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=args.quant_dtype,
            per_act_token_quant=args.per_token_quantized_activations,
            per_out_ch_quant=args.per_channel_quantized_weights,
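For context, the benchmark now builds its quant config through the FusedMoEQuantConfig.make(...) factory instead of the constructor. A minimal standalone sketch of the same call, assuming an FP8 quant dtype purely for illustration (the benchmark passes args.quant_dtype and the two per-token / per-channel flags straight from its CLI arguments):

    import torch
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    # Illustrative values only.
    quant_config = FusedMoEQuantConfig.make(
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=True,
        per_out_ch_quant=False,
    )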
@@ -28,7 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.import_utils import (
    has_aiter,
    has_deep_ep,
    has_deep_gemm,
    has_mori,
    has_pplx,
)

from .mk_objects import (
    TestMoEQuantConfig,
@@ -211,6 +217,14 @@ class Config:
            or info.backend == "deepep_low_latency"
        )

    def needs_aiter(self):
        info = expert_info(self.fused_experts_type)
        return info.needs_aiter

    def needs_mori(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend == "mori"

    def all2all_backend(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend
@@ -278,6 +292,10 @@ class Config:
            return False, "Needs DeepGEMM, but DeepGEMM not available."
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
            return False, "Needs PPLX, but PPLX not available."
        if self.needs_aiter() and not has_aiter():  # noqa: SIM103
            return False, "Needs Aiter, but Aiter not available."
        if self.needs_mori() and not has_mori():  # noqa: SIM103
            return False, "Needs MoRI, but MoRI not available."

        return True, None
@@ -37,7 +37,13 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.import_utils import (
    has_aiter,
    has_deep_ep,
    has_deep_gemm,
    has_mori,
    has_pplx,
)


@dataclass
@@ -66,6 +72,7 @@ class ExpertInfo:
    supports_expert_map: bool
    needs_matching_quant: bool = False
    needs_deep_gemm: bool = False
    needs_aiter: bool = False


PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
@@ -126,6 +133,7 @@ def register_experts(
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
    needs_aiter: bool = False,
):
    global EXPERT_INFO
    global MK_FUSED_EXPERT_TYPES
@@ -139,6 +147,7 @@ def register_experts(
        supports_expert_map,
        needs_matching_quant,
        needs_deep_gemm,
        needs_aiter,
    )

    MK_FUSED_EXPERT_TYPES.append(kind)
@@ -218,6 +227,20 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
        backend="deepep_low_latency",
    )

if has_mori():
    from vllm.model_executor.layers.fused_moe.mori_prepare_finalize import (
        MoriPrepareAndFinalize,
    )

    register_prepare_and_finalize(
        MoriPrepareAndFinalize,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        backend="mori",
        supports_apply_weight_on_input=False,
    )

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
@@ -261,6 +284,25 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
    )
else:
    FlashInferCutlassMoEPrepareAndFinalize = None
    FlashInferExperts = None


if has_aiter():
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        AiterExperts,
    )

    register_experts(
        AiterExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_aiter=True,
    )
else:
    AiterExperts = None

if has_deep_gemm() and is_deep_gemm_supported():
    register_experts(
@@ -316,6 +358,9 @@ if cutlass_fp8_supported():
        supports_chunking=False,
        supports_expert_map=False,
    )
else:
    CutlassBatchedExpertsFp8 = None
    CutlassExpertsFp8 = None

if cutlass_fp4_supported():
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
@@ -328,6 +373,8 @@ if cutlass_fp4_supported():
        supports_chunking=True,
        supports_expert_map=False,
    )
else:
    CutlassExpertsFp4 = None

MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
    None,
@@ -79,6 +79,8 @@ def _rocm_aiter_fused_moe_impl(
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
@@ -100,6 +102,8 @@ def _rocm_aiter_fused_moe_impl(
        w2_scale,
        a1_scale,
        a2_scale,
        num_local_tokens=num_local_tokens,
        dtype=output_dtype,
    )
@@ -117,7 +121,11 @@ def _rocm_aiter_fused_moe_fake(
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    if output_dtype is not None:
        return torch.empty_like(hidden_states, dtype=output_dtype)
    return torch.empty_like(hidden_states)
@@ -1236,6 +1244,8 @@ class rocm_aiter_ops:
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
        num_local_tokens: torch.Tensor | None = None,
        output_dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
@@ -1251,6 +1261,8 @@ class rocm_aiter_ops:
            w2_scale,
            a1_scale,
            a2_scale,
            num_local_tokens,
            output_dtype,
        )

    @staticmethod
@@ -43,6 +43,7 @@ All2AllBackend = Literal[
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "mori",
    "allgather_reducescatter",
    "flashinfer_all2allv",
]
@@ -158,6 +159,7 @@ class ParallelConfig:
    - "pplx": Use pplx kernels\n
    - "deepep_high_throughput": Use deepep high-throughput kernels\n
    - "deepep_low_latency": Use deepep low-latency kernels\n
    - "mori": Use mori kernels\n
    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""

    max_parallel_loading_workers: int | None = None
@@ -443,6 +445,7 @@ class ParallelConfig:
                "naive",
                "deepep_high_throughput",
                "deepep_low_latency",
                "mori",
            )
            and self.enable_expert_parallel
            and self.tensor_parallel_size > 1
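The new "mori" value is an ordinary entry in the existing backend choices, so it can be selected like any other all2all backend. A hedged sketch of configuring it programmatically (the field names follow this diff; other ParallelConfig arguments and defaults may differ):

    from vllm.config import ParallelConfig

    # Hypothetical EP deployment: 8-way tensor parallel with expert parallelism,
    # using the MoRI all2all backend added in this commit.
    parallel_config = ParallelConfig(
        tensor_parallel_size=8,
        enable_expert_parallel=True,
        all2all_backend="mori",
    )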
@@ -10,7 +10,7 @@ from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache
@@ -507,3 +507,96 @@ class FlashInferAllToAllManager(All2AllManagerBase):
        self.prepare_workspace_tensor = None
        self.mapping = None
        self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
    def __init__(self, cpu_group):
        assert has_mori(), (
            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
            " to install MoRI kernels."
        )  # noqa
        import mori

        super().__init__(cpu_group)
        self.handle_cache = Cache()

        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
        mori.shmem.shmem_torch_process_group_init("mori")

    def _make_all2all_kwargs(
        self,
        rank: int,
        num_ep_ranks: int,
        input_dtype: torch.dtype,
        quant_dtype: torch.dtype,
        token_hidden_size: int,
        scale_dim: int,
        scale_type_size: int,
        max_num_tokens_per_dp_rank: int,
        num_local_experts: int,
        num_experts_per_token: int,
    ):
        import mori  # type: ignore[import-not-found]

        from vllm.platforms.rocm import on_gfx942, on_gfx950

        assert on_gfx942() or on_gfx950(), (
            "mori currently only support arch gfx942 and gfx950"
        )

        if not self.internode:
            # single node
            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
            rdma_block_num = 0
            warp_num_per_block = 16
            block_num = 80
        else:
            # multi node
            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
            if on_gfx942():
                warp_num_per_block = 16
                block_num = 32
                rdma_block_num = 16
            elif on_gfx950():
                warp_num_per_block = 8
                block_num = 64
                rdma_block_num = 32
            else:
                raise NotImplementedError(
                    "mori currently only support arch gfx942 and gfx950"
                )

        return dict(
            rank=rank,
            world_size=num_ep_ranks,
            data_type=quant_dtype,
            hidden_dim=token_hidden_size,
            scale_dim=scale_dim,
            scale_type_size=scale_type_size,
            max_token_type_size=input_dtype.itemsize,
            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
            num_experts_per_rank=num_local_experts,
            num_experts_per_token=num_experts_per_token,
            warp_num_per_block=warp_num_per_block,
            block_num=block_num,
            kernel_type=kernel_type,
            rdma_block_num=rdma_block_num,
            gpu_per_node=min(8, num_ep_ranks),
        )

    def _make_handle(self, **kwargs):
        import mori  # type: ignore[import-not-found]

        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
        handle = mori.ops.EpDispatchCombineOp(mori_config)
        return handle

    def get_handle(self, kwargs):
        import mori  # type: ignore[import-not-found]

        mori_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("MoRI all2all args %s", mori_kwargs)
        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
            mori_kwargs, self._make_handle
        )
        return handle
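A sketch of how a caller obtains a dispatch/combine handle from the manager above. The values are hypothetical and only illustrate the keyword set that _make_all2all_kwargs() expects; the real call site is maybe_make_prepare_finalize(), shown later in this diff:

    import torch

    all_to_all_args = dict(
        rank=0,                                  # EP rank of this worker
        num_ep_ranks=8,                          # world size of the EP group
        input_dtype=torch.bfloat16,              # activation dtype before quantization
        quant_dtype=torch.float8_e4m3fn,         # dtype used on the wire when dispatching
        token_hidden_size=4096,                  # hidden size per token (illustrative)
        scale_dim=4096 // 128,                   # one scale per 128-wide block
        scale_type_size=torch.float32.itemsize,  # scales are stored as fp32
        max_num_tokens_per_dp_rank=256,          # illustrative per-rank token cap
        num_local_experts=32,                    # experts hosted on this rank
        num_experts_per_token=8,                 # router top-k
    )
    # mori_manager: a MoriAll2AllManager instance created by the communicator.
    handle = mori_manager.get_handle(all_to_all_args)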
@@ -110,6 +110,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
            from .all2all import DeepEPLLAll2AllManager

            self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
        elif self.all2all_backend == "mori":
            from .all2all import MoriAll2AllManager

            self.all2all_manager = MoriAll2AllManager(self.cpu_group)
        elif self.all2all_backend == "flashinfer_all2allv":
            from .all2all import FlashInferAllToAllManager
@@ -187,6 +187,7 @@ if TYPE_CHECKING:
        "pplx",
        "deepep_high_throughput",
        "deepep_low_latency",
        "mori",
        "allgather_reducescatter",
        "flashinfer_all2allv",
    ] = "allgather_reducescatter"
@@ -1298,6 +1299,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # - "pplx": use pplx kernels
    # - "deepep_high_throughput", use deepep high-throughput kernels
    # - "deepep_low_latency", use deepep low-latency kernels
    # - "mori", use MoRI kernels
    # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
    "VLLM_ALL2ALL_BACKEND": env_with_choices(
        "VLLM_ALL2ALL_BACKEND",
@@ -1307,6 +1309,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
            "pplx",
            "deepep_high_throughput",
            "deepep_low_latency",
            "mori",
            "allgather_reducescatter",
            "flashinfer_all2allv",
        ],
@@ -88,6 +88,9 @@ if HAS_TRITON:
        fused_experts,
        get_config_file_name,
    )
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        AiterExperts,
    )
    from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
        fused_topk,
    )
@@ -99,6 +102,7 @@ if HAS_TRITON:
    )

    __all__ += [
        "AiterExperts",
        "fused_topk",
        "fused_experts",
        "get_config_file_name",
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

if current_platform.is_cuda_alike():
    if has_pplx():
@@ -30,6 +30,8 @@ if current_platform.is_cuda_alike():
            DEEPEP_QUANT_BLOCK_SHAPE,
            DeepEPLLPrepareAndFinalize,
        )
    if has_mori():
        from .mori_prepare_finalize import MoriPrepareAndFinalize


def maybe_roundup_layer_hidden_size(
@@ -169,5 +171,36 @@ def maybe_make_prepare_finalize(
            physical_to_global=physical_to_global,
            local_expert_global_ids=local_expert_global_ids,
        )
    elif moe.use_mori_kernels:
        assert quant_config is not None

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        use_fp8_dispatch = (
            quant_config.is_per_act_token or quant_config.is_block_quantized
        )
        # For PTPC (per token per channel) quant, the scale dim for each token is 1
        # For 1x128 quant, the scale dim for each token is hidden_dim // 128
        scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
        all_to_all_args = dict(
            rank=all2all_manager.rank,
            num_ep_ranks=all2all_manager.world_size,
            quant_dtype=quant_config.quant_dtype,
            token_hidden_size=moe.hidden_dim,
            scale_dim=scale_dim,
            scale_type_size=torch.float32.itemsize,
            max_num_tokens_per_dp_rank=moe.max_num_tokens,
            input_dtype=moe.in_dtype,
            num_local_experts=moe.num_experts // all2all_manager.world_size,
            num_experts_per_token=moe.experts_per_token,
        )
        handle = all2all_manager.get_handle(all_to_all_args)

        prepare_finalize = MoriPrepareAndFinalize(
            handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=all2all_manager.world_size,
            use_fp8_dispatch=use_fp8_dispatch,
        )

    return prepare_finalize
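To make the scale_dim comment above concrete, here is the arithmetic for an assumed hidden size of 4096 (the value is purely illustrative, not taken from the diff):

    hidden_dim = 4096                      # assumed model hidden size, illustration only
    scale_dim_per_act_token = 1            # PTPC quant: a single scale per token
    scale_dim_blocked = hidden_dim // 128  # 1x128 block quant: 32 scales per token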
@@ -893,6 +893,10 @@ class FusedMoEParallelConfig:
            self.all2all_backend in ["naive", "allgather_reducescatter"]
        )

    @property
    def use_mori_kernels(self):
        return self.use_all2all_kernels and self.all2all_backend == "mori"

    @staticmethod
    def flatten_tp_across_dp_and_pcp(
        tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -1136,6 +1140,10 @@ class FusedMoEConfig:
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_mori_kernels(self):
        return self.moe_parallel_config.use_mori_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        """
@@ -570,6 +570,14 @@ class FusedMoE(CustomOp):
        self.moe_config_use_flashinfer_cutlass_kernels = (
            self.moe_config.use_flashinfer_cutlass_kernels
        )
        if self.use_mori_kernels:
            assert self.rocm_aiter_fmoe_enabled, (
                "Mori needs to be used with aiter fused_moe for now."
            )
            assert not self.aiter_fmoe_shared_expert_enabled, (
                "Mori does not support fusion shared expert now. "
                "Turn it off by setting VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0"
            )

        self.quant_config = quant_config
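Taken together, the assertions above mean MoRI currently rides on the AITER fused-MoE path and is incompatible with the fused shared-expert optimization. A hedged sketch of an environment that satisfies both constraints (flag names other than the one quoted in the assertion should be double-checked against vllm.envs):

    import os

    # Select the MoRI all2all backend and the AITER fused-MoE experts it requires,
    # and disable the fused shared-expert path that MoRI does not support yet.
    os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS"] = "0"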
@@ -712,6 +720,10 @@ class FusedMoE(CustomOp):
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_mori_kernels(self):
        return self.moe_parallel_config.use_mori_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        return (
@@ -729,6 +741,7 @@ class FusedMoE(CustomOp):
        return (
            self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels
            or self.moe_parallel_config.use_mori_kernels
            or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
        ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import mori
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform

logger = init_logger(__name__)


class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize using MoRI kernels.
    """

    def __init__(
        self,
        mori_op: mori.ops.EpDispatchCombineOp,
        max_tokens_per_rank: int,
        num_dispatchers: int,
        use_fp8_dispatch: bool = False,
    ):
        super().__init__()
        self.mori_op = mori_op
        self.num_dispatchers_ = num_dispatchers
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def output_is_reduced(self) -> bool:
        return True

    def num_dispatchers(self):
        return self.num_dispatchers_

    def max_num_tokens_per_rank(self) -> int | None:
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> torch.dtype | None:
        return torch.int32

    def supports_async(self) -> bool:
        return False

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        """
        Returns a tuple of:
        - quantized + dispatched a.
        - Optional quantized + dispatched a1_scales.
        - Optional ExpertTokensMetadata containing gpu/cpu tensors
          as big as the number of local experts with the information about the
          number of tokens assigned to each local expert.
        - Optional dispatched expert topk IDs
        - Optional dispatched expert topk weight
        """
        assert not apply_router_weight_on_input, (
            "mori does not support apply_router_weight_on_input=True now."
        )
        scale = None
        if self.use_fp8_dispatch:
            from aiter import QuantType, get_hip_quant

            if quant_config.is_block_quantized:
                quant_func = get_hip_quant(QuantType.per_1x128)
                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())
            elif quant_config.is_per_act_token:
                quant_func = get_hip_quant(QuantType.per_Token)
                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())

        (
            dispatch_a1,
            dispatch_weights,
            dispatch_scale,
            dispatch_ids,
            dispatch_recv_token_num,
        ) = self.mori_op.dispatch(a1, topk_weights, scale, topk_ids)

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=dispatch_recv_token_num, expert_num_tokens_cpu=None
        )

        return (
            dispatch_a1,
            dispatch_scale,
            expert_tokens_meta,
            dispatch_ids,
            dispatch_weights,
        )

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        num_token = output.shape[0]
        result = self.mori_op.combine(
            fused_expert_output,
            None,
            topk_ids,
        )[0]
        output.copy_(result[:num_token])
@@ -188,6 +188,9 @@ def rocm_aiter_fused_experts(
    apply_router_weight_on_input: bool = False,
    expert_map: torch.Tensor | None = None,
    quant_config: FusedMoEQuantConfig | None = None,
    a1q_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    if quant_config is None:
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
@@ -216,6 +219,9 @@ def rocm_aiter_fused_experts(
        assert topk_weights.shape[-1] == 1, (
            "Only support topk=1 when `apply_router_weight_on_input` is True"
        )
        assert num_local_tokens is None, (
            "AITER tkw1 kernel does not support `num_local_tokens`"
        )

        return rocm_aiter_ops.asm_moe_tkw1(
            hidden_states,
@@ -272,9 +278,11 @@ def rocm_aiter_fused_experts(
        activation_method=activation_method,
        w1_scale=quant_config.w1_scale,
        w2_scale=quant_config.w2_scale,
        a1_scale=quant_config.a1_scale,
        a1_scale=quant_config.a1_scale if a1q_scale is None else a1q_scale,
        a2_scale=quant_config.a2_scale,
        doweight_stage1=apply_router_weight_on_input,
        num_local_tokens=num_local_tokens,
        output_dtype=output_dtype,
    )
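The a1_scale change above lets a scale produced during prepare (for example MoRI's FP8 dispatch) take precedence over a static scale stored in the quant config. A minimal sketch of that selection with hypothetical tensors:

    import torch

    static_a1_scale = torch.tensor(0.05)  # hypothetical static scale from quant_config
    a1q_scale = torch.rand(16, 1)         # hypothetical per-token scales from dispatch
    effective_a1_scale = static_a1_scale if a1q_scale is None else a1q_scale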
@@ -370,9 +378,12 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
        # TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
        # a_scales for static quantization. Update this to fit better
        # with the interface once all quant integrations are complete.
        assert a1q_scale is None
        assert a2_scale == self.quant_config.a2_scale
        assert expert_tokens_meta is None

        if expert_tokens_meta is not None:
            num_local_tokens = expert_tokens_meta.expert_num_tokens
        else:
            num_local_tokens = None

        result = rocm_aiter_fused_experts(
            hidden_states=hidden_states,
@@ -384,6 +395,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            quant_config=self.quant_config,
            a1q_scale=a1q_scale,
            num_local_tokens=num_local_tokens,
            output_dtype=output.dtype,
        )
        assert result.shape == output.shape
        output.copy_(result)
@@ -106,6 +106,12 @@ def on_gfx9() -> bool:
    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])


@cache
def on_gfx942() -> bool:
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(arch in GPU_ARCH for arch in ["gfx942"])


@cache
def on_gfx950() -> bool:
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
@@ -451,3 +451,13 @@ def has_helion() -> bool:
    # use helion...
    """
    return _has_module("helion")


def has_aiter() -> bool:
    """Whether the optional `aiter` package is available."""
    return _has_module("aiter")


def has_mori() -> bool:
    """Whether the optional `mori` package is available."""
    return _has_module("mori")