[Perf] Add TRTLLM FP8 MoE Modular Kernel (#36307)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
|
||||
TrtLlmFp8Experts,
|
||||
TrtLlmFp8ExpertsMonolithic,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
@@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
allow_new_interface=True,
|
||||
use_monolithic=True,
|
||||
),
|
||||
TrtLlmFp8Experts(
|
||||
TrtLlmFp8ExpertsMonolithic(
|
||||
moe_config=td.layer.moe,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
@@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
|
||||
class TrtLlmFp8ExpertsBase:
|
||||
"""
|
||||
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
Fp8 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic
|
||||
interfaces.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
@@ -44,24 +49,7 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# Make additional scales for per-tensor interface.
|
||||
if self.quant_config.is_per_tensor:
|
||||
w1_scale = self.quant_config.w1_scale
|
||||
assert w1_scale is not None
|
||||
a1_scale = self.quant_config.a1_scale
|
||||
assert a1_scale is not None
|
||||
w2_scale = self.quant_config.w2_scale
|
||||
assert w2_scale is not None
|
||||
a2_scale = self.quant_config.a2_scale
|
||||
assert a2_scale is not None
|
||||
|
||||
self._g1_alphas = (w1_scale * a1_scale).squeeze()
|
||||
self._g2_alphas = (w2_scale * a2_scale).squeeze()
|
||||
self._g1_scale_c = (
|
||||
self._g1_alphas / self.quant_config.a2_scale
|
||||
if moe_config.is_act_and_mul
|
||||
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
|
||||
)
|
||||
self.quant_config = quant_config
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
@@ -79,50 +67,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
@@ -153,6 +102,178 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
    """
    Fp8 TRTLLM-Gen MoE kernels. Supports modular interface.

    Routing is done by the caller: `apply` receives the already-selected
    `topk_ids` / `topk_weights` and forwards them, packed into one integer
    tensor, to FlashInfer's `trtllm_fp8_block_scale_routed_moe`.
    Only the Fp8 128-block quantization scheme is advertised as supported.
    """

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Supports Fp8 block."""
        SUPPORTED_W_A = [
            (kFp8Static128BlockSym, kFp8Dynamic128Sym),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Return the `(workspace1, workspace2, output)` shapes.

        Both workspaces are requested as empty `(0,)` buffers; only the
        output shape `(M, K)` depends on the arguments.
        """
        # The workspaces for this implementation are managed by flashinfer.
        workspace1 = (0,)
        workspace2 = (0,)
        output = (M, K)

        return (workspace1, workspace2, output)

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a no-op top-k weight-and-reduce implementation."""
        return TopKWeightAndReduceNoOP()

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ) -> None:
        """
        Run the routed Fp8 block-scale MoE kernel and copy its result into
        `output`.

        `activation`, `expert_map`, `a2_scale`, `workspace13`, `workspace2`,
        `expert_tokens_meta` and `apply_router_weight_on_input` are accepted
        for interface compatibility but are not read here.

        Returns early, leaving `output` untouched, while FlashInfer
        autotuning is in progress. `a1q_scale` must not be None otherwise.
        """
        import flashinfer

        import vllm.utils.flashinfer as fi_utils

        # trtllm_fp8_block_scale_routed_moe does not support autotuning
        # so skip this kernel during dummy run for autotuning.
        # Checked first so no packing work is done on the skipped path.
        if fi_utils._is_fi_autotuning:
            return

        assert a1q_scale is not None

        # Pack topk_ids and topk_weights into single tensor
        # Format: (expert_id << 16) | (weight_bf16.view(int16))
        # The int16 view is signed: a weight whose sign bit is set (negative
        # value or -0.0) would sign-extend when promoted and overwrite the
        # expert id in the upper bits, so keep only the low 16 bits.
        weight_bits = topk_weights.to(torch.bfloat16).view(torch.int16)
        packed_topk_ids = (topk_ids << 16) | (weight_bits.to(torch.int32) & 0xFFFF)

        # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
        # output tensor in-place so we need to manually copy the result to the
        # output tensor
        # https://github.com/flashinfer-ai/flashinfer/issues/2703
        result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
            topk_ids=packed_topk_ids,
            routing_bias=None,
            hidden_states=hidden_states,
            hidden_states_scale=a1q_scale.t().contiguous(),  # type: ignore[union-attr]
            gemm1_weights=w1,
            gemm1_weights_scale=self.quant_config.w1_scale,
            gemm2_weights=w2,
            gemm2_weights_scale=self.quant_config.w2_scale,
            num_experts=global_num_experts,
            top_k=self.topk,
            n_group=None,
            topk_group=None,
            intermediate_size=self.intermediate_size_per_partition,
            local_expert_offset=self.ep_rank * self.local_num_experts,
            local_num_experts=self.local_num_experts,
            routed_scaling_factor=None,
            # NOTE(review): hard-coded routing-method id; presumably not
            # significant here because routing was already done by the
            # caller -- confirm against the FlashInfer API.
            routing_method_type=1,
            use_shuffled_weight=False,
            weight_layout=0,
            # output=output,
        )
        output.copy_(result)
|
||||
|
||||
|
||||
class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic):
|
||||
"""
|
||||
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
|
||||
# Make additional scales for per-tensor interface.
|
||||
if self.quant_config.is_per_tensor:
|
||||
w1_scale = self.quant_config.w1_scale
|
||||
assert w1_scale is not None
|
||||
a1_scale = self.quant_config.a1_scale
|
||||
assert a1_scale is not None
|
||||
w2_scale = self.quant_config.w2_scale
|
||||
assert w2_scale is not None
|
||||
a2_scale = self.quant_config.a2_scale
|
||||
assert a2_scale is not None
|
||||
|
||||
self._g1_alphas = (w1_scale * a1_scale).squeeze()
|
||||
self._g2_alphas = (w2_scale * a2_scale).squeeze()
|
||||
self._g1_scale_c = (
|
||||
self._g1_alphas / self.quant_config.a2_scale
|
||||
if moe_config.is_act_and_mul
|
||||
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
def _apply_per_block(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -104,83 +104,84 @@ def _get_priority_backends(
|
||||
|
||||
def backend_to_kernel_cls(
    backend: Fp8MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """
    Map an FP8 MoE backend to its candidate expert-kernel classes.

    Every backend yields a list; callers try the classes in list order.
    Only FLASHINFER_TRTLLM has more than one candidate (monolithic first,
    then modular). Each kernel module is imported inside its own branch,
    so only the requested backend's module is loaded.

    Raises:
        ValueError: if `backend` is not one of the handled backends.
    """
    if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (  # noqa: E501
            TrtLlmFp8ExpertsModular,
            TrtLlmFp8ExpertsMonolithic,
        )

        return [TrtLlmFp8ExpertsMonolithic, TrtLlmFp8ExpertsModular]

    elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return [FlashInferExperts]

    elif backend == Fp8MoeBackend.DEEPGEMM:
        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
            TritonOrDeepGemmExperts,
        )

        return [TritonOrDeepGemmExperts]

    elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
        from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
            BatchedDeepGemmExperts,
        )

        return [BatchedDeepGemmExperts]

    elif backend == Fp8MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    elif backend == Fp8MoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            TritonExperts,
        )

        return [TritonExperts]

    elif backend == Fp8MoeBackend.BATCHED_TRITON:
        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
            BatchedTritonExperts,
        )

        return [BatchedTritonExperts]

    elif backend == Fp8MoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        return [AiterExperts]

    elif backend == Fp8MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
            TritonOrCutlassExperts,
        )

        return [TritonOrCutlassExperts]

    elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            CutlassBatchedExpertsFp8,
        )

        return [CutlassBatchedExpertsFp8]

    elif backend == Fp8MoeBackend.XPU:
        from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
            XPUExpertsFp8,
        )

        return [XPUExpertsFp8]

    else:
        raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
||||
@@ -215,8 +216,9 @@ def select_fp8_moe_backend(
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
|
||||
if config.is_lora_enabled:
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0]
|
||||
|
||||
# NOTE: the kernels are selected in the following order.
|
||||
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
|
||||
@@ -256,13 +258,13 @@ def select_fp8_moe_backend(
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
@@ -312,7 +314,7 @@ def select_fp8_moe_backend(
|
||||
raise ValueError(
|
||||
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
k_cls = backend_to_kernel_cls(backend)[0]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -322,23 +324,23 @@ def select_fp8_moe_backend(
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
|
||||
"FlashInfer FP8 MoE backend supports the configuration."
|
||||
@@ -382,20 +384,19 @@ def select_fp8_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
# TODO(rob): per discussion with TPU team, we need a way to register
|
||||
# MoE backends by OOT plugins, rather than having an explicit list
|
||||
|
||||
Reference in New Issue
Block a user