[Kernels] Remove BatchedTritonOrDeepGemmExperts and default fallback to Triton (#29929)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -60,9 +60,6 @@ if HAS_TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8,
|
||||
@@ -98,7 +95,6 @@ if HAS_TRITON:
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
else:
|
||||
# Some model classes directly use the custom ops. Add placeholders
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
|
||||
|
||||
|
||||
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = (
|
||||
allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8
|
||||
and self.block_shape == get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
|
||||
self.batched_deep_gemm_experts = (
|
||||
BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
if self.allow_deep_gemm
|
||||
else None
|
||||
)
|
||||
|
||||
assert (
|
||||
self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.batched_triton_experts is not None:
|
||||
assert (
|
||||
self.batched_deep_gemm_experts is None
|
||||
or self.batched_deep_gemm_experts.activation_formats
|
||||
== self.batched_triton_experts.activation_formats
|
||||
)
|
||||
return self.batched_triton_experts.activation_formats
|
||||
else:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return (bdge is None or bdge.supports_chunking()) and (
|
||||
bte is None or bte.supports_chunking()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return (bdge is None or bdge.supports_expert_map()) and (
|
||||
bte is None or bte.supports_expert_map()
|
||||
)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
|
||||
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
|
||||
is_bdge_war = bdge_war is not None
|
||||
is_bte_war = bte_war is not None
|
||||
|
||||
if is_bdge_war and is_bte_war:
|
||||
assert bdge_war == bte_war, (
|
||||
"Both implementations should agree on WeightAndReduce impls. "
|
||||
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}"
|
||||
)
|
||||
|
||||
if bdge_war is not None:
|
||||
return bdge_war
|
||||
|
||||
assert bte_war is not None
|
||||
return bte_war
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_metadata,
|
||||
)
|
||||
else:
|
||||
assert self.batched_triton_experts is not None
|
||||
return self.batched_triton_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_metadata,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
experts = (
|
||||
self.batched_deep_gemm_experts
|
||||
if self.allow_deep_gemm
|
||||
else self.batched_triton_experts
|
||||
)
|
||||
assert experts is not None
|
||||
experts.apply(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
@@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
return experts
|
||||
|
||||
# triton path
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
@@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||
|
||||
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -1105,22 +1111,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(
|
||||
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
),
|
||||
if use_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
"DeepGEMM requested for MoE layer but not installed."
|
||||
)
|
||||
|
||||
compatible_with_deep_gemm = (
|
||||
self.moe_quant_config.use_fp8_w8a8
|
||||
and self.moe_quant_config.block_shape
|
||||
== get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
|
||||
# If this MoE layer is compatible with DeepGEMM, the proper env
|
||||
# vars are set and DeepGEMM is not installed, throw an error.
|
||||
if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
f"MoE layer incompatible with DeepGEMM, expected "
|
||||
f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
|
||||
f"or block_shape {self.moe_quant_config.block_shape}"
|
||||
f"=={get_mk_alignment_for_contiguous_layout()}."
|
||||
)
|
||||
|
||||
if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
|
||||
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return TritonOrDeepGemmExperts(
|
||||
self.moe_quant_config,
|
||||
allow_deep_gemm=(
|
||||
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
),
|
||||
allow_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
|
||||
Reference in New Issue
Block a user