[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Robert Shaw <rshaw@neuralmagic.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
@@ -44,7 +44,8 @@ steps:
|
||||
- vllm/envs.py
|
||||
- vllm/config
|
||||
commands:
|
||||
- pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
- pytest -v -s kernels/moe --ignore=kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
- pytest -v -s kernels/moe/test_modular_oai_triton_moe.py --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
|
||||
@@ -12,12 +12,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
@@ -137,15 +137,21 @@ def bench_run(
|
||||
per_out_ch_quant=per_out_ch,
|
||||
)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=num_experts,
|
||||
hidden_dim=k,
|
||||
intermediate_size_per_partition=n,
|
||||
in_dtype=a.dtype,
|
||||
)
|
||||
fn = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
moe_config=make_dummy_moe_config(
|
||||
num_experts=num_experts,
|
||||
hidden_dim=k,
|
||||
intermediate_size_per_partition=n,
|
||||
in_dtype=a.dtype,
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
@@ -196,10 +196,21 @@ def bench_run(
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=num_experts,
|
||||
hidden_dim=k,
|
||||
intermediate_size_per_partition=n,
|
||||
in_dtype=a.dtype,
|
||||
)
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp4(
|
||||
make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
@@ -240,11 +251,17 @@ def bench_run(
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp4(
|
||||
make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
@@ -131,16 +131,22 @@ def bench_run(
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fn = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
moe_config=make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
@@ -163,16 +169,22 @@ def bench_run(
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fn = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
moe_config=make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,6 +17,9 @@ from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -242,24 +245,33 @@ def benchmark_config(
|
||||
|
||||
deep_gemm_experts = None
|
||||
if use_deep_gemm:
|
||||
deep_gemm_experts = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
|
||||
moe_config = (
|
||||
FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
hidden_dim=hidden_size,
|
||||
intermediate_size_per_partition=shard_intermediate_size,
|
||||
num_local_experts=num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
activation=MoEActivation.SILU,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
in_dtype=init_dtype,
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
device="cuda",
|
||||
),
|
||||
)
|
||||
deep_gemm_experts = mk.FusedMoEKernel(
|
||||
prepare_finalize=maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
fused_experts=TritonOrDeepGemmExperts(
|
||||
moe_config=FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
hidden_dim=hidden_size,
|
||||
intermediate_size_per_partition=shard_intermediate_size,
|
||||
num_local_experts=num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
activation=MoEActivation.SILU,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
in_dtype=init_dtype,
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
device="cuda",
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not disable_inplace(),
|
||||
)
|
||||
|
||||
with override_config(config):
|
||||
@@ -269,8 +281,16 @@ def benchmark_config(
|
||||
|
||||
inplace = not disable_inplace()
|
||||
if use_deep_gemm:
|
||||
return deep_gemm_experts(
|
||||
x, w1, w2, topk_weights, topk_ids, inplace=inplace
|
||||
return deep_gemm_experts.apply(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=num_experts,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=False,
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
|
||||
@@ -81,7 +81,7 @@ The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` cal
|
||||
|
||||
The `make_ubatch_context` function initializes two `UBatchContexts`, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContexts` and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContexts`. It will handle all of the event initialization.
|
||||
|
||||
The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalize` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
|
||||
The `dbo_register_recv_hook` method registers a callback that can be returned by the `FusedMoEPrepareAndFinalizeModular` class in the other UBatch thread’s `UBatchContext`. The callback will be run when the other thread calls `dbo_maybe_run_recv_hook`. This is typically used to wait on an all-to-all kernel.
|
||||
|
||||
The `dbo_maybe_run_recv_hook` method runs a callback that’s set by the `dbo_register_recv_hook` function if that callback exists.
|
||||
|
||||
|
||||
@@ -37,31 +37,31 @@ The rest of the document will focus on the Contiguous / Non-Batched case. Extrap
|
||||
FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
|
||||
|
||||
1. TopKWeightAndReduce
|
||||
2. FusedMoEPrepareAndFinalize
|
||||
3. FusedMoEPermuteExpertsUnpermute
|
||||
2. FusedMoEPrepareAndFinalizeModular
|
||||
3. FusedMoEExpertsModular
|
||||
|
||||
### TopKWeightAndReduce
|
||||
|
||||
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
|
||||
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEExpertsModular` is responsible for the Unpermute and `FusedMoEPrepareAndFinalizeModular` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEExpertsModular`. But some implementations choose to do it in `FusedMoEPrepareAndFinalizeModular`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
|
||||
|
||||
Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py).
|
||||
|
||||
`FusedMoEPrepareAndFinalize::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
|
||||
The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExpertsUnpermute` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens.
|
||||
`FusedMoEPrepareAndFinalizeModular::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method.
|
||||
The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEExpertsModular` and `FusedMoEPrepareAndFinalizeModular` implementations to determine where the TopK Weight Application and Reduction happens.
|
||||
|
||||
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEPermuteExpertsUnpermute` implementation does the weight application and reduction itself.
|
||||
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEPermuteExpertsUnpermute` implementation needs the `FusedMoEPrepareAndFinalize::finalize()` to do the weight application and reduction.
|
||||
* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceNoOp` if the `FusedMoEExpertsModular` implementation does the weight application and reduction itself.
|
||||
* `FusedMoEExpertsModular::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEExpertsModular` implementation needs the `FusedMoEPrepareAndFinalizeModular::finalize()` to do the weight application and reduction.
|
||||
|
||||
### FusedMoEPrepareAndFinalize
|
||||
### FusedMoEPrepareAndFinalizeModular
|
||||
|
||||
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
|
||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalize` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||
The `FusedMoEPrepareAndFinalizeModular` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
|
||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, The `prepare_no_receive` is like `prepare` except it does not wait to receive results from other workers. Instead it returns a "receiver" callback that must be invoked to wait for the final results of worker. It is not required that this method is supported by all `FusedMoEPrepareAndFinalizeModular` classes, but if it is available, it can be used to interleave work with the initial all to all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||
|
||||

|
||||

|
||||
|
||||
### FusedMoEPermuteExpertsUnpermute
|
||||
### FusedMoEExpertsModular
|
||||
|
||||
The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operations happen. The `FusedMoEPermuteExpertsUnpermute` abstract class exposes a few important functions,
|
||||
The `FusedMoEExpertsModular` class is where the crux of the MoE operations happen. The `FusedMoEExpertsModular` abstract class exposes a few important functions,
|
||||
|
||||
* apply()
|
||||
* workspace_shapes()
|
||||
@@ -81,25 +81,25 @@ The `apply` method is where the implementations perform
|
||||
|
||||
#### workspace_shapes()
|
||||
|
||||
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEPermuteExpertsUnpermute::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
|
||||
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEExpertsModular::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
|
||||
|
||||
#### finalize_weight_and_reduce_impl()
|
||||
|
||||
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
|
||||
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use.
|
||||
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEExpertsModular::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
|
||||
`FusedMoEExpertsModular::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalizeModular::finalize()` to use.
|
||||
|
||||

|
||||

|
||||
|
||||
### FusedMoEModularKernel
|
||||
|
||||
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` objects.
|
||||
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` objects.
|
||||
`FusedMoEModularKernel` pseudocode/sketch,
|
||||
|
||||
```py
|
||||
class FusedMoEModularKernel:
|
||||
def __init__(self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute):
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
fused_experts: FusedMoEExpertsModular):
|
||||
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
@@ -128,53 +128,53 @@ class FusedMoEModularKernel:
|
||||
|
||||
## How-To
|
||||
|
||||
### How To Add a FusedMoEPrepareAndFinalize Type
|
||||
### How To Add a FusedMoEPrepareAndFinalizeModular Type
|
||||
|
||||
Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
|
||||
Typically a FusedMoEPrepareAndFinalizeModular type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
|
||||
|
||||
* DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and
|
||||
* DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
|
||||
|
||||
#### Step 1: Add an All2All manager
|
||||
|
||||
The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py).
|
||||
The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalizeModular` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py).
|
||||
|
||||
#### Step 2: Add a FusedMoEPrepareAndFinalize Type
|
||||
#### Step 2: Add a FusedMoEPrepareAndFinalizeModular Type
|
||||
|
||||
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize` abstract class.
|
||||
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalizeModular` abstract class.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
`FusedMoEPrepareAndFinalizeModular::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
|
||||
`FusedMoEPrepareAndFinalizeModular::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
`FusedMoEPrepareAndFinalizeModular::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
|
||||
`FusedMoEPrepareAndFinalizeModular::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
|
||||
`FusedMoEPrepareAndFinalizeModular::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None.
|
||||
`FusedMoEPrepareAndFinalizeModular::topk_indices_dtype()`: Data type of the TopK ids. Some All2All kernels have strict requirements pertaining to the data type of the TopK ids. This requirement is passed on to the `FusedMoe::select_experts` function so it could be respected. If there are no strict requirements return None.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once.
|
||||
`FusedMoEPrepareAndFinalizeModular::max_num_tokens_per_rank()`: This is the maximum number of tokens that would be submitted to the All2All Dispatch at once.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
|
||||
`FusedMoEPrepareAndFinalizeModular::num_dispatchers()`: Total number of dispatching units. This value determines the size of the Dispatch output. The Dispatch output is of shape (num_local_experts, max_num_tokens, K). Here max_num_tokens = num_dispatchers() * max_num_tokens_per_rank().
|
||||
|
||||
We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementation that matches your All2All implementation closely and using it as a reference.
|
||||
We suggest picking an already existing `FusedMoEPrepareAndFinalizeModular` implementation that matches your All2All implementation closely and using it as a reference.
|
||||
|
||||
### How To Add a FusedMoEPermuteExpertsUnpermute Type
|
||||
### How To Add a FusedMoEExpertsModular Type
|
||||
|
||||
FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
|
||||
FusedMoEExpertsModular performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
|
||||
|
||||
`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
|
||||
`FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
|
||||
|
||||
`FusedMoEPermuteExpertsUnpermute::supports_chunking()`: Return True if the implementation supports chunking. Typically
|
||||
`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
|
||||
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
|
||||
|
||||
`FusedMoEPermuteExpertsUnpermute::supports_expert_map()`: Return True if the implementation supports expert map.
|
||||
`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
|
||||
|
||||
`FusedMoEPermuteExpertsUnpermute::workspace_shapes()` /
|
||||
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` /
|
||||
`FusedMoEPermuteExpertsUnpermute::apply`: Refer to `FusedMoEPermuteExpertsUnpermute` section above.
|
||||
`FusedMoEExpertsModular::workspace_shapes()` /
|
||||
`FusedMoEExpertsModular::finalize_weight_and_reduce_impl` /
|
||||
`FusedMoEExpertsModular::apply`: Refer to `FusedMoEExpertsModular` section above.
|
||||
|
||||
### FusedMoEModularKernel Initialization
|
||||
|
||||
@@ -186,14 +186,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
|
||||
|
||||
#### maybe_make_prepare_finalize
|
||||
|
||||
The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
|
||||
The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalizeModular` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalizeModular` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
|
||||
Please refer to the implementations in,
|
||||
|
||||
* `ModelOptNvFp4FusedMoE`
|
||||
|
||||
#### select_gemm_impl
|
||||
|
||||
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
|
||||
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEExpertsModular` object.
|
||||
Please refer to the implementations in,
|
||||
|
||||
* `UnquantizedFusedMoEMethod`
|
||||
@@ -205,7 +205,7 @@ derived classes.
|
||||
|
||||
#### init_prepare_finalize
|
||||
|
||||
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalize` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEPermuteExpertsUnpermute` object and builds the `FusedMoEModularKernel` object
|
||||
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalizeModular` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEExpertsModular` object and builds the `FusedMoEModularKernel` object
|
||||
|
||||
Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188).
|
||||
**Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
|
||||
@@ -214,9 +214,9 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl
|
||||
|
||||
We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py).
|
||||
|
||||
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPremuteExpertsUnpermute` types and if they are
|
||||
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types and if they are
|
||||
compatible, runs some correctness tests.
|
||||
If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnpermute` implementations,
|
||||
If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsModular` implementations,
|
||||
|
||||
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
|
||||
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
|
||||
@@ -225,24 +225,24 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp
|
||||
|
||||
Doing this will add the new implementation to the test suite.
|
||||
|
||||
### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility
|
||||
### How To Check `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` Compatibility
|
||||
|
||||
The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
|
||||
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
|
||||
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalizeModular` & `FusedMoEExpertsModular` compatibility. When invoked
|
||||
with incompatible types, the script will error.
|
||||
|
||||
### How To Profile
|
||||
|
||||
Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py)
|
||||
The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible
|
||||
`FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types.
|
||||
`FusedMoEPrepareAndFinalizeModular` and `FusedMoEExpertsModular` types.
|
||||
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type DeepEPLLPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||
|
||||
## FusedMoEPrepareAndFinalize Implementations
|
||||
## FusedMoEPrepareAndFinalizeModular Implementations
|
||||
|
||||
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-modular-all2all-backends) for a list of all the available modular prepare and finalize subclasses.
|
||||
|
||||
## FusedMoEPermuteExpertsUnpermute
|
||||
## FusedMoEExpertsModular
|
||||
|
||||
See [Fused MoE Kernel features](./moe_kernel_features.md#fused-moe-experts-kernels) for a list of all the available modular experts.
|
||||
|
||||
@@ -4,17 +4,17 @@ The purpose of this document is to provide an overview of the various MoE kernel
|
||||
|
||||
## Fused MoE Modular All2All backends
|
||||
|
||||
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalize` subclasses provide an interface for each all2all backend.
|
||||
There are a number of all2all communication backends that are used to implement expert parallelism (EP) for the `FusedMoE` layer. The different `FusedMoEPrepareAndFinalizeModular` subclasses provide an interface for each all2all backend.
|
||||
|
||||
The following table describes the relevant features of each backend, i.e. activation format, supported quantization schemes and async support.
|
||||
|
||||
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalize` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
|
||||
The output activation format (standard or batched) corresponds to the output of the prepare step of the `FusedMoEPrepareAndFinalizeModular` subclass, and the finalize step requires the same format. All the backend `prepare` methods expect activations in the standard format and all the `finalize` methods return activations in standard format. More details on the formats can be found in the [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) document.
|
||||
|
||||
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalize` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
|
||||
The quantization types and formats enumerate which quantization schemes are supported by each `FusedMoEPrepareAndFinalizeModular` class. The quantization can happen before or after the dispatch based on the format the all2all backend supports, e.g. deepep_high_throughput supports only block-quantized fp8 format. Any other format will result in dispatching in higher precision and quantizing afterwards. The output of the prepare step for each backend is the quantized type. The finalize step generally requires the same input type as the original activations, e.g. if the original input is bfloat16 and the quantization scheme is fp8 with per-tensor scales, `prepare` will return fp8/per-tensor scale activations and `finalize` will take bfloat16 activations. See the diagrams in [Fused MoE Modular Kernel](./fused_moe_modular_kernel.md) for more details on the types and formats of activations at each step of the MoE process. If no quantization type is specified, the kernel operates on float16 and/or bfloat16.
|
||||
|
||||
Async backends support the use of DBO (Dual Batch Overlap) and shared expert overlap (where shared experts are computed during the combine step).
|
||||
|
||||
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalize` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
|
||||
Certain models require the topk weights to be applied to the input activations rather than the output activations when topk==1, e.g. Llama. For modular kernels, this feature is supported by the `FusedMoEPrepareAndFinalizeModular` subclass. For non-modular kernels, it is up to the experts function to deal with this flag.
|
||||
|
||||
Unless otherwise specified, backends are controlled via the `--all2all-backend` command-line argument (or the `all2all_backend` parameter in `ParallelConfig`). All backends except `flashinfer` only work with EP+DP or EP+TP. `Flashinfer` can work with EP or DP without EP.
|
||||
|
||||
@@ -36,8 +36,6 @@ th {
|
||||
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
|
||||
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
|
||||
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
|
||||
| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
|
||||
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
|
||||
|
||||
!!! info "Table key"
|
||||
1. All types: mxfp4, nvfp4, int4, int8, fp8
|
||||
@@ -75,9 +73,9 @@ Each experts kernel supports one or more activation functions, e.g. silu or gelu
|
||||
|
||||
As with the backends, some experts support applying topk weights on the input activations. The entries in the column in this table only apply to the non-modular experts.
|
||||
|
||||
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEPermuteExpertsUnpermute`.
|
||||
Most experts flavors include an equivalent modular interface which will be a subclass of `FusedMoEExpertsModular`.
|
||||
|
||||
To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
|
||||
To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE kernels must have compatible activation formats, quantization types and quantization formats.
|
||||
|
||||
| Kernel | Input act. format | Quant. types | Quant. format | Activation function | Apply Weight On Input | Modular | Source |
|
||||
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
|
||||
@@ -106,7 +104,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
||||
|
||||
The following table shows "families" of modular kernels that are intended to work together. There are some combinations which may work but have not yet been tested, e.g. flashinfer with other fp8 experts. Note that the "naive" backend will work with any non-modular experts.
|
||||
|
||||
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
||||
| backend | `FusedMoEPrepareAndFinalizeModular` subclasses | `FusedMoEExpertsModular` subclasses |
|
||||
|---------|-----------------------------------------|----------------------------------------------|
|
||||
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
|
||||
| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||
|
||||
@@ -17,13 +17,13 @@ from .mk_objects import (
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
if fe.__name__ == s:
|
||||
return fe
|
||||
|
||||
@@ -66,7 +66,7 @@ class Config:
|
||||
quant_config: TestMoEQuantConfig | None
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
fused_experts_type: mk.FusedMoEExperts
|
||||
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
@@ -566,7 +566,7 @@ def make_modular_kernel(
|
||||
config: Config,
|
||||
vllm_config: VllmConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
@@ -613,7 +613,7 @@ def make_modular_kernel(
|
||||
config.N,
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
modular_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
@@ -667,6 +667,7 @@ def run_modular_kernel(
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"activation": MoEActivation.SILU,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1
|
||||
@@ -684,6 +685,6 @@ def run_modular_kernel(
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
out = mk.forward(**mk_kwargs)
|
||||
out = mk.apply(**mk_kwargs)
|
||||
|
||||
return out
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
@@ -71,12 +71,14 @@ class ExpertInfo:
|
||||
needs_aiter: bool = False
|
||||
|
||||
|
||||
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||
PREPARE_FINALIZE_INFO: dict[
|
||||
mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo
|
||||
] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = []
|
||||
|
||||
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo:
|
||||
|
||||
|
||||
register_prepare_and_finalize(
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
standard_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
@@ -239,14 +241,14 @@ if has_mori():
|
||||
|
||||
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
@@ -430,12 +432,12 @@ def make_cutlass_strides(
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fused_experts_type: mk.FusedMoEExpertsModular,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
num_dispatchers: int,
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if (
|
||||
fused_experts_type.activation_format()
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
@@ -72,7 +72,7 @@ def profile_modular_kernel(
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
|
||||
do_profile(mk.forward, mk_kwargs, pgi, config)
|
||||
do_profile(mk.apply, mk_kwargs, pgi, config)
|
||||
|
||||
|
||||
def rank_worker(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
@@ -12,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
|
||||
|
||||
from .test_deepgemm import make_block_quant_fp8_weights
|
||||
@@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(
|
||||
mk_triton = FusedMoEKernel(
|
||||
prep_finalize,
|
||||
triton_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_triton = mk_triton(
|
||||
out_triton = mk_triton.apply(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=E,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
# deepgemm
|
||||
@@ -96,19 +100,22 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(
|
||||
mk_deepgemm = FusedMoEKernel(
|
||||
prep_finalize,
|
||||
deepgemm_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_deepgemm = mk_deepgemm(
|
||||
out_deepgemm = mk_deepgemm.apply(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=E,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
|
||||
@@ -21,15 +21,16 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
@@ -193,7 +194,17 @@ def test_w8a8_block_fp8_fused_moe(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
m_out = m_fused_moe.apply(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
global_num_experts=w1.shape[0],
|
||||
)
|
||||
|
||||
# 0.039 only needed for M >= 8192
|
||||
tol = 0.035 if M < 8192 else 0.039
|
||||
@@ -252,23 +263,33 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
deep_gemm_experts = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
|
||||
deep_gemm_experts = mk.FusedMoEKernel(
|
||||
prepare_finalize=maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
fused_experts=TritonOrDeepGemmExperts(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
|
||||
return deep_gemm_experts(
|
||||
return deep_gemm_experts.apply(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=E,
|
||||
activation=MoEActivation.SILU,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=False,
|
||||
)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
|
||||
@@ -13,6 +13,9 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -22,9 +25,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -197,20 +197,26 @@ def run_with_expert_maps(
|
||||
for kwargs, new_quant_config in slice_experts():
|
||||
w2 = kwargs["w2"]
|
||||
a = kwargs["hidden_states"]
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
)
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=new_quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
moe_config=make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=new_quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
out_tensor = out_tensor + kernel(**kwargs)
|
||||
out_tensor = out_tensor + kernel.apply(**kwargs)
|
||||
|
||||
return out_tensor
|
||||
|
||||
@@ -252,25 +258,35 @@ def run_8_bit(
|
||||
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
|
||||
"activation": MoEActivation.SILU,
|
||||
"expert_map": None,
|
||||
"apply_router_weight_on_input": False,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
||||
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||
in_dtype=moe_tensors.a.dtype,
|
||||
)
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
moe_config=make_dummy_moe_config(
|
||||
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
||||
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||
in_dtype=moe_tensors.a.dtype,
|
||||
),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
return kernel(**kwargs)
|
||||
return kernel.apply(**kwargs)
|
||||
|
||||
assert num_local_experts is not None
|
||||
return run_with_expert_maps(
|
||||
|
||||
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
@@ -170,7 +170,7 @@ def make_ll_modular_kernel(
|
||||
q_dtype: torch.dtype | None,
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
) -> FusedMoEKernel:
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
|
||||
@@ -195,7 +195,7 @@ def make_ll_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
return FusedMoEModularKernel(
|
||||
return FusedMoEKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
@@ -210,7 +210,7 @@ def make_ht_modular_kernel(
|
||||
q_dtype: torch.dtype | None,
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
) -> FusedMoEKernel:
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
|
||||
@@ -228,7 +228,7 @@ def make_ht_modular_kernel(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
return FusedMoEModularKernel(
|
||||
return FusedMoEKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
@@ -242,11 +242,11 @@ def make_modular_kernel(
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
) -> FusedMoEKernel:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
|
||||
mk: FusedMoEModularKernel
|
||||
mk: FusedMoEKernel
|
||||
# Make modular kernel
|
||||
if test_config.low_latency:
|
||||
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl(
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
mk: FusedMoEKernel = make_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl(
|
||||
with with_dp_metadata(
|
||||
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
|
||||
):
|
||||
out = mk.forward(
|
||||
out = mk.apply(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
@@ -135,7 +135,7 @@ def make_modular_kernel(
|
||||
q_dtype: torch.dtype | None,
|
||||
use_fp8_dispatch: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
) -> FusedMoEKernel:
|
||||
ht_args: DeepEPHTArgs | None = None
|
||||
ll_args: DeepEPLLArgs | None = None
|
||||
|
||||
@@ -180,7 +180,7 @@ def make_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
mk = FusedMoEModularKernel(
|
||||
mk = FusedMoEKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
@@ -242,7 +242,7 @@ def deep_ep_moe_impl(
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
mk: FusedMoEKernel = make_modular_kernel(
|
||||
pg,
|
||||
pgi,
|
||||
low_latency_mode,
|
||||
@@ -255,7 +255,7 @@ def deep_ep_moe_impl(
|
||||
quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(
|
||||
out = mk.apply(
|
||||
hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
|
||||
@@ -14,13 +14,16 @@ import torch
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
@@ -108,11 +111,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
deep_gemm_experts = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
|
||||
deep_gemm_experts = mk.FusedMoEKernel(
|
||||
prepare_finalize=maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
fused_experts=TritonOrDeepGemmExperts(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
@@ -130,12 +139,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
)
|
||||
|
||||
# DeepGemm
|
||||
out_deepgemm = deep_gemm_experts(
|
||||
out_deepgemm = deep_gemm_experts.apply(
|
||||
hidden_states=tokens_bf16,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
activation=MoEActivation.SILU,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
@@ -8,6 +8,9 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
RoutingMethodType,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
|
||||
TrtLlmFp8Experts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe,
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
@@ -115,6 +116,7 @@ class TestData:
|
||||
e: int,
|
||||
is_trtllm: bool,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
topk: int = 1,
|
||||
) -> "TestData":
|
||||
is_gated = activation.is_gated
|
||||
|
||||
@@ -152,13 +154,6 @@ class TestData:
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer.w13_weight, layer.w2_weight, is_gated
|
||||
)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
layer.w13_weight_scale,
|
||||
layer.w13_input_scale,
|
||||
layer.w2_weight_scale,
|
||||
layer.w2_input_scale,
|
||||
)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.routing_method_type = RoutingMethodType.Llama4
|
||||
layer.renormalize = False
|
||||
@@ -166,6 +161,21 @@ class TestData:
|
||||
layer.ep_rank = 0
|
||||
layer.local_num_experts = e
|
||||
|
||||
layer.moe = FusedMoEConfig(
|
||||
num_experts=e,
|
||||
experts_per_token=topk,
|
||||
hidden_dim=k,
|
||||
intermediate_size_per_partition=n,
|
||||
num_local_experts=e,
|
||||
num_logical_experts=e,
|
||||
moe_parallel_config=layer.moe_parallel_config,
|
||||
in_dtype=hidden_states.dtype,
|
||||
is_act_and_mul=is_gated,
|
||||
routing_method=layer.routing_method_type,
|
||||
activation=activation,
|
||||
device=w13_quantized.device,
|
||||
)
|
||||
|
||||
return TestData(
|
||||
hidden_states=hidden_states,
|
||||
w13_quantized=w13_quantized,
|
||||
@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=td.layer,
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=td.layer.moe,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=True,
|
||||
),
|
||||
TrtLlmFp8Experts(
|
||||
moe_config=td.layer.moe,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
flashinfer_output = kernel.apply_monolithic(
|
||||
hidden_states=td.hidden_states,
|
||||
w1=td.layer.w13_weight,
|
||||
w2=td.layer.w2_weight,
|
||||
router_logits=score,
|
||||
routing_bias=None,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
check_accuracy(
|
||||
@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
FlashInferExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
flashinfer_cutlass_output = kernel(
|
||||
flashinfer_cutlass_output = kernel.apply(
|
||||
td.hidden_states,
|
||||
td.layer.w13_weight,
|
||||
td.layer.w2_weight,
|
||||
|
||||
@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -107,19 +107,27 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
)
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
flashinfer_experts = FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
flashinfer_output = flashinfer_experts.apply(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
|
||||
@@ -221,16 +221,16 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
w1_marlin,
|
||||
w2_marlin,
|
||||
None,
|
||||
None,
|
||||
w1_scales_marlin,
|
||||
w2_scales_marlin,
|
||||
None, # gating_output not needed when topk_weights/ids provided
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
hidden_states=a,
|
||||
w1=w1_marlin,
|
||||
w2=w2_marlin,
|
||||
bias1=None,
|
||||
bias2=None,
|
||||
w1_scale=w1_scales_marlin,
|
||||
w2_scale=w2_scales_marlin,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_type_id=scalar_types.uint4b8.id,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
global_scale1=None,
|
||||
@@ -244,7 +244,6 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
|
||||
w1_zeros=None,
|
||||
w2_zeros=None,
|
||||
input_dtype=dtype,
|
||||
quant_type_id=scalar_types.uint4b8.id,
|
||||
is_k_full=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -168,7 +168,6 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
|
||||
def is_nyi_config(config: Config) -> bool:
|
||||
# We know these configs to be legitimate. but still fail.
|
||||
info = expert_info(config.fused_experts_type)
|
||||
|
||||
if info.needs_matching_quant:
|
||||
# The triton kernels expect both per-act-token-quant and
|
||||
# per-out-ch-quant or neither.
|
||||
@@ -259,7 +258,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
dtype: torch.dtype,
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
@@ -301,7 +300,7 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
dtype: torch.dtype,
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
|
||||
@@ -7,6 +7,7 @@ Test modular OAI Triton MoE
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
@@ -24,15 +25,15 @@ from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
OAITritonExperts,
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -174,19 +175,25 @@ def oai_triton_moe_impl(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
if unfused:
|
||||
fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config)
|
||||
fused_experts = UnfusedOAITritonExperts(moe_config, quant_config)
|
||||
else:
|
||||
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
|
||||
fused_experts = OAITritonExperts(moe_config, quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
mk = FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return mk.forward(
|
||||
return mk.apply(
|
||||
hidden_states=x,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
@@ -217,6 +224,7 @@ def test_oai_triton_moe(
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)
|
||||
set_random_seed(0)
|
||||
(
|
||||
w1,
|
||||
|
||||
@@ -346,14 +346,16 @@ def test_fused_moe(
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
return m_fused_moe_fn(
|
||||
return m_fused_moe_fn.apply(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
|
||||
@@ -500,14 +502,16 @@ def test_naive_block_assignment_moe(
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
return m_fused_moe_fn(
|
||||
return m_fused_moe_fn.apply(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
|
||||
|
||||
@@ -15,12 +15,15 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -89,22 +92,32 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
CutlassExpertsFp4(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
cutlass_output = kernel(
|
||||
cutlass_output = kernel.apply(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=e,
|
||||
activation=mk.MoEActivation.SILU,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
@@ -207,8 +220,8 @@ def test_cutlass_fp4_moe_swiglustep(
|
||||
w2_scale=w2_blockscale,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
make_moe_prepare_and_finalize_no_dp_ep(use_monolithic=False),
|
||||
CutlassExpertsFp4(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
@@ -216,13 +229,16 @@ def test_cutlass_fp4_moe_swiglustep(
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
cutlass_output = kernel(
|
||||
cutlass_output = kernel.apply(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=MoEActivation.SWIGLUSTEP,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
# Reference: dequantize everything and run torch_moe with swiglustep
|
||||
|
||||
@@ -8,6 +8,9 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts,
|
||||
fused_experts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
@@ -125,7 +125,9 @@ def batched_moe(
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
fused_experts = FusedMoEKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
@@ -133,12 +135,22 @@ def batched_moe(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
return fused_experts.apply(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
global_num_experts=w1.shape[0],
|
||||
activation=moe_config.activation,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
|
||||
def naive_batched_moe(
|
||||
@@ -166,8 +178,9 @@ def naive_batched_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
moe_config = make_dummy_moe_config()
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
fused_experts = FusedMoEKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
@@ -175,12 +188,22 @@ def naive_batched_moe(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
moe_config=moe_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
return fused_experts.apply(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
global_num_experts=w1.shape[0],
|
||||
activation=moe_config.activation,
|
||||
apply_router_weight_on_input=False,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
|
||||
def chunk_scales(
|
||||
@@ -581,9 +604,14 @@ def modular_triton_fused_moe(
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> FusedMoEModularKernel:
|
||||
return FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
) -> FusedMoEKernel:
|
||||
return FusedMoEKernel(
|
||||
maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=False,
|
||||
),
|
||||
TritonExperts(moe_config, quant_config),
|
||||
shared_experts,
|
||||
inplace=False,
|
||||
|
||||
@@ -127,6 +127,14 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_fp8_block_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"deepseek-ai/DeepSeek-V3.1",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=triton"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"Known issue: lack of kernel support. "
|
||||
@@ -149,6 +157,14 @@ def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatc
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_nvfp4_moe_flashinfer_vllm(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/DeepSeek-R1-0528-FP4-v2",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=cutlass"],
|
||||
)
|
||||
|
||||
|
||||
def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/DeepSeek-R1-0528-FP4-v2",
|
||||
@@ -200,3 +216,67 @@ def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=flashinfer_trtllm"],
|
||||
)
|
||||
|
||||
|
||||
## NemoTron ##
|
||||
|
||||
|
||||
def test_nemotron_fp8_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=flashinfer_cutlass"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"FP8 MoE backend FLASHINFER_TRTLLM does not support the "
|
||||
"deployment configuration since kernel does not support "
|
||||
"no act_and_mul MLP layer."
|
||||
)
|
||||
)
|
||||
def test_nemotron_fp8_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=flashinfer_trtllm"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"FP8 MoE backend TRITON does not support the "
|
||||
"deployment configuration since kernel does not support "
|
||||
"no act_and_mul MLP layer."
|
||||
)
|
||||
)
|
||||
def test_nemotron_fp8_moe_vllm_triton(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=triton"],
|
||||
)
|
||||
|
||||
|
||||
def test_nemotron_fp4_moe_flashinfer_throughput(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=flashinfer_cutlass"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"FP4 MoE backend FLASHINFER_TRTLLM does not support the "
|
||||
"deployment configuration since kernel does not support "
|
||||
"hidden_dim % 512 != 0."
|
||||
)
|
||||
)
|
||||
def test_nemotron_fp4_moe_flashinfer_latency(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4",
|
||||
hf_overrides=HF_OVERRIDE_TEXT,
|
||||
extra_args=["--moe-backend=flashinfer_trtllm"],
|
||||
)
|
||||
|
||||
@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
FusedMoEKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
|
||||
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
|
||||
@@ -136,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
|
||||
# Use the existing modular kernel from the quant method
|
||||
m_fused_moe_fn = self.base_layer.quant_method.moe_mk
|
||||
m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
|
||||
# Don't let the kernel own shared experts so the runner can
|
||||
# overlap them with routed experts via a separate CUDA stream.
|
||||
m_fused_moe_fn.shared_experts = None
|
||||
@@ -144,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
# Create a new modular kernel via select_gemm_impl.
|
||||
# Don't pass shared_experts to the kernel so the runner can
|
||||
# overlap them with routed experts via a separate CUDA stream.
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
||||
m_fused_moe_fn = FusedMoEModularKernel(
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
|
||||
m_fused_moe_fn = FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
self.base_layer.quant_method.select_gemm_impl(
|
||||
prepare_finalize, self.base_layer
|
||||
@@ -154,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
assert isinstance(
|
||||
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
|
||||
m_fused_moe_fn.impl.fused_experts,
|
||||
(MarlinExperts, UnfusedOAITritonExperts),
|
||||
)
|
||||
else:
|
||||
assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)
|
||||
assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
|
||||
|
||||
def fwd_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -337,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
return wrapper
|
||||
|
||||
fused_experts = m_fused_moe_fn.fused_experts
|
||||
fused_experts = m_fused_moe_fn.impl.fused_experts
|
||||
|
||||
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
|
||||
m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
|
||||
fused_experts.activation = act_decorator(
|
||||
self.base_layer, fused_experts.activation
|
||||
)
|
||||
|
||||
@@ -22,8 +22,8 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
|
||||
FusedMoERouter,
|
||||
@@ -62,9 +62,9 @@ __all__ = [
|
||||
"MoEActivation",
|
||||
"UnquantizedFusedMoEMethod",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEExpertsModular",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"FusedMoEPrepareAndFinalizeModular",
|
||||
"GateLinear",
|
||||
"RoutingMethodType",
|
||||
"SharedFusedMoE",
|
||||
|
||||
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNaiveEP,
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
@@ -77,6 +77,7 @@ def maybe_make_prepare_finalize(
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
allow_new_interface: bool = False,
|
||||
use_monolithic: bool = False,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
# NOTE(rob): we are migrating each quant_method to hold the MK
|
||||
# in all cases. The allow_new_interface=False flag allow us to fall
|
||||
@@ -102,14 +103,15 @@ def maybe_make_prepare_finalize(
|
||||
"Detected DP deployment with no --enable-expert-parallel. "
|
||||
"Falling back to AllGather+ReduceScatter dispatch/combine."
|
||||
)
|
||||
return MoEPrepareAndFinalizeNaiveEP(
|
||||
return make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=(
|
||||
get_ep_group().device_communicator.all2all_manager.world_size
|
||||
),
|
||||
use_monolithic=use_monolithic,
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
@@ -201,8 +203,9 @@ def maybe_make_prepare_finalize(
|
||||
)
|
||||
|
||||
elif moe.use_naive_all2all_kernels and allow_new_interface:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
|
||||
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
|
||||
prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic=use_monolithic,
|
||||
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
|
||||
return y_q, y_s
|
||||
|
||||
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
|
||||
_a2: FusedMoEQuantDesc
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
is_nvfp4_scale_swizzled: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
weight_dtype: torch.dtype | str | None = None,
|
||||
is_nvfp4_scale_swizzled: bool = True,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
|
||||
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||
- is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling.
|
||||
"""
|
||||
assert not isinstance(quant_dtype, str) or quant_dtype in {
|
||||
"nvfp4",
|
||||
@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
|
||||
_w2=FusedMoEQuantDesc(
|
||||
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
|
||||
),
|
||||
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
|
||||
w2_scale: torch.Tensor,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
is_nvfp4_scale_swizzled: bool = True,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
|
||||
return
|
||||
|
||||
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
"""CUTLASS FP4 fused MoE expert implementation."""
|
||||
|
||||
@property
|
||||
@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
)
|
||||
|
||||
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: torch.dtype | None,
|
||||
@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fn = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
CutlassExpertsW4A8Fp8(
|
||||
out_dtype=a.dtype,
|
||||
a_strides1=a_strides1,
|
||||
@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
|
||||
quant_config=quant_config,
|
||||
group_size=group_size,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fn(
|
||||
return fn.apply(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
|
||||
@@ -113,7 +113,7 @@ def _valid_deep_gemm(
|
||||
return True
|
||||
|
||||
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
"""DeepGemm-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
|
||||
)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -49,7 +49,7 @@ def dequant_fp8(
|
||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
|
||||
|
||||
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using DeepEP low-latency kernels.
|
||||
"""
|
||||
@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# time. This setting is handled by post_init_setup.
|
||||
self.use_ue8m0_dispatch = False
|
||||
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
|
||||
if not fused_experts.supports_packed_ue8m0_act_scales():
|
||||
# Early exit.
|
||||
return
|
||||
|
||||
335
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Normal file
335
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
||||
"""
|
||||
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
|
||||
if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
|
||||
raise NotImplementedError(
|
||||
"EP parallelism is not supported with TRTLLM"
|
||||
"per-tensor FP8 quantization."
|
||||
)
|
||||
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# Make additional scales for per-tensor interface.
|
||||
if self.quant_config.is_per_tensor:
|
||||
w1_scale = self.quant_config.w1_scale
|
||||
assert w1_scale is not None
|
||||
a1_scale = self.quant_config.a1_scale
|
||||
assert a1_scale is not None
|
||||
w2_scale = self.quant_config.w2_scale
|
||||
assert w2_scale is not None
|
||||
a2_scale = self.quant_config.a2_scale
|
||||
assert a2_scale is not None
|
||||
|
||||
self._g1_alphas = (w1_scale * a1_scale).squeeze()
|
||||
self._g2_alphas = (w2_scale * a2_scale).squeeze()
|
||||
self._g1_scale_c = (
|
||||
self._g1_alphas / self.quant_config.a2_scale
|
||||
if moe_config.is_act_and_mul
|
||||
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
# Add check flashinfer trtllm is available
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
or moe_parallel_config.use_naive_all2all_kernels
|
||||
) and not moe_parallel_config.enable_eplb
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def _apply_per_block(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
|
||||
assert not apply_router_weight_on_input
|
||||
assert activation == MoEActivation.SILU
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
|
||||
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
assert self.topk <= global_num_experts
|
||||
assert self.topk <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert self.quant_config.block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# Kernel requires transposed hidden state scales
|
||||
# TODO: fuse into the quant kernel.
|
||||
assert a1q_scale is not None
|
||||
a1q_scale_t = a1q_scale.t().contiguous()
|
||||
|
||||
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale_t,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
def _apply_per_tensor(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Delay import for non-CUDA.
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# Confirm supported activation function.
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
activation_type = ActivationType(activation_to_flashinfer_int(activation))
|
||||
|
||||
# Confirm Llama-4 routing is proper.
|
||||
if self.routing_method_type == RoutingMethodType.Llama4:
|
||||
assert apply_router_weight_on_input
|
||||
else:
|
||||
assert not apply_router_weight_on_input
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
gemm1_weights=w1,
|
||||
output1_scales_scalar=self._g1_scale_c,
|
||||
output1_scales_gate_scalar=self._g1_alphas,
|
||||
gemm2_weights=w2,
|
||||
output2_scales_scalar=self._g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=num_expert_group or 0,
|
||||
topk_group=topk_group or 0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=self.routing_method_type,
|
||||
activation_type=activation_type,
|
||||
)
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.quant_config.block_shape is not None:
|
||||
return self._apply_per_block(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
elif self.quant_config.is_per_tensor:
|
||||
return self._apply_per_tensor(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
router_logits,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only per-block and per-tensor quantization are supported in "
|
||||
f"{self.__class__.__name__}."
|
||||
)
|
||||
326
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Normal file
326
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsBase:
|
||||
"""
|
||||
NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
self.moe_config = moe_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.routing_method_type = self.moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
assert self.quant_config.g1_alphas is not None
|
||||
assert self.quant_config.a2_gscale is not None
|
||||
if moe_config.is_act_and_mul:
|
||||
# g1_alpha_s = a13_scale * w13_scale_2
|
||||
# a2_gscale = (1 / a2_scale)
|
||||
# g1_scale_c = a13_scale * w13_scale_2 / a2_scale
|
||||
self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
|
||||
else:
|
||||
self.g1_scale_c = (
|
||||
torch.ones_like(self.quant_config.a1_gscale)
|
||||
* self.quant_config.a2_gscale
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE (i.e. Nemotron-Nano)."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Nvfp4 quantization."""
|
||||
SUPPORTED_W_A = [
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""Requires hidden dim to be multiple of 512."""
|
||||
return hidden_dim % 512 == 0
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
Modular version of the implementation (just the experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation supports all parallel configs."""
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
|
||||
# Hidden states are Nvfp4, packed into int8 dtype, so we
|
||||
# need to multiply K by 2 to get the output shape right.
|
||||
assert self.hidden_dim == K * 2
|
||||
output = (M, self.hidden_dim)
|
||||
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
# trtllm_fp4_block_scale_routed_moe does not support autotuning
|
||||
# so skip this kernel during dummy run for autotuning.
|
||||
import vllm.utils.flashinfer as fi_utils
|
||||
|
||||
if fi_utils._is_fi_autotuning:
|
||||
return hidden_states
|
||||
|
||||
# Invoke kernel.
|
||||
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
topk_ids=packed_tensor,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=0,
|
||||
topk_group=0,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
do_finalize=True,
|
||||
activation_type=activation_to_flashinfer_int(activation),
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmNvFp4ExpertsMonolithic(
|
||||
TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
|
||||
):
|
||||
"""
|
||||
Monolithic version of the kernel (router + experts).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""The modular implementation should be used for the Dp/Ep or EPLB case."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method_type: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# NOTE(rob): this is a conservative list.
|
||||
return routing_method_type in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert a1q_scale is not None
|
||||
assert self.quant_config.w1_scale is not None
|
||||
assert self.quant_config.w2_scale is not None
|
||||
assert (
|
||||
apply_router_weight_on_input
|
||||
and self.routing_method_type == RoutingMethodType.Llama4
|
||||
) or (
|
||||
not apply_router_weight_on_input
|
||||
and self.routing_method_type != RoutingMethodType.Llama4
|
||||
)
|
||||
|
||||
# Prepare routing bias into kernel format.
|
||||
routing_bias = e_score_correction_bias
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(torch.bfloat16)
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits
|
||||
)
|
||||
|
||||
# Invoke kernel.
|
||||
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
),
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=self.g1_scale_c,
|
||||
output1_scale_gate_scalar=self.quant_config.g1_alphas,
|
||||
output2_scale_scalar=self.quant_config.g2_alphas,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=(num_expert_group or 0),
|
||||
topk_group=(topk_group or 0),
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
|
||||
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
|
||||
"""Base class for runtime dispatching of expert implementations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
experts: mk.FusedMoEExpertsModular,
|
||||
fallback_experts: mk.FusedMoEExpertsModular,
|
||||
):
|
||||
super().__init__(
|
||||
moe_config=experts.moe_config, quant_config=experts.quant_config
|
||||
@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
"""
|
||||
Get the cls for the experts and fallback experts.
|
||||
@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(
|
||||
|
||||
@@ -18,7 +18,7 @@ def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""Base class for FlashInfer MoE prepare and finalize operations."""
|
||||
|
||||
def __init__(
|
||||
@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
|
||||
ep_size,
|
||||
)
|
||||
|
||||
# Swizzle after the A2A if nvfp4.
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
# Swizzle after the A2A if MoE kernel expects swizzled scales.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
if x_sf.element_size() == 1:
|
||||
x_sf = x_sf.view(torch.uint8)
|
||||
x_sf = nvfp4_block_scale_interleave(x_sf)
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
|
||||
return True
|
||||
|
||||
|
||||
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: mk.FusedMoEConfig,
|
||||
|
||||
@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Fp8 per-tensor and Fp8 block."""
|
||||
SUPPORTED_W_A = [
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
||||
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
|
||||
def _supports_routing_method_bf16(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
|
||||
return not moe_parallel_config.enable_eplb
|
||||
|
||||
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return True
|
||||
|
||||
|
||||
def is_supported_config_trtllm_fp8(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_quant_scheme(weight_key, activation_key):
|
||||
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method(
|
||||
weight_key, activation_key, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
elif not _supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype, moe_config.routing_method
|
||||
):
|
||||
return False, _make_reason(
|
||||
"float32 router_logits with non-DeepSeekV3 routing "
|
||||
f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def is_supported_config_trtllm_bf16(
|
||||
moe_config: FusedMoEConfig,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
|
||||
return True, None
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float | None = 1.0,
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 10
|
||||
assert global_num_experts % 4 == 0
|
||||
assert block_shape == [128, 128]
|
||||
# Routing kernel expects #experts <= #threads 512
|
||||
assert global_num_experts <= 512
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(x.dtype)
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale_inv,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale_inv,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling,
|
||||
routing_method_type=routing_method_type,
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routing_method_type: int,
|
||||
routed_scaling: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
|
||||
|
||||
# The DeepSeekV3 routing method requires float32 router logits.
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
routing_logits = routing_logits.to(torch.float32)
|
||||
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=quant_hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
output1_scales_scalar=output1_scales_scalar,
|
||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||
gemm2_weights=gemm2_weights,
|
||||
output2_scales_scalar=output2_scales_scalar,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
routing_method_type=routing_method_type,
|
||||
# TODO: enum type Required for flashinfer==0.6.3, remove with update
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/2508
|
||||
activation_type=ActivationType(activation_type),
|
||||
)
|
||||
|
||||
|
||||
def fi_trtllm_fp8_per_tensor_moe_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
activation_type: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="fi_trtllm_fp8_per_tensor_moe",
|
||||
op_func=fi_trtllm_fp8_per_tensor_moe,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_bf16(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
|
||||
@@ -489,7 +489,7 @@ def invoke_moe_batched_triton_kernel(
|
||||
)
|
||||
|
||||
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
A reference prepare/finalize class that reorganizes the tokens into
|
||||
expert batched format, i.e. E x max_num_tokens x K. This is the format
|
||||
@@ -645,7 +645,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
|
||||
|
||||
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A reference MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
@@ -877,7 +877,7 @@ def batched_moe_kernel_quantize_input(
|
||||
return A_q, A_q_scale
|
||||
|
||||
|
||||
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
A Triton based MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the batched
|
||||
|
||||
@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
|
||||
return output
|
||||
|
||||
|
||||
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -1736,7 +1736,7 @@ def fused_experts_impl(
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||
activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
|
||||
N, activation_enum
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
@@ -1924,7 +1924,7 @@ def fused_experts_impl(
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
"""Triton-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
super().__init__()
|
||||
self.moe: FusedMoEConfig = moe
|
||||
self.moe_quant_config: FusedMoEQuantConfig | None = None
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
@property
|
||||
def supports_internal_mk(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None
|
||||
return self.moe_kernel is not None
|
||||
|
||||
@property
|
||||
def mk_owns_shared_expert(self) -> bool:
|
||||
# NOTE(rob): temporary attribute to indicate support for
|
||||
# completed migration to the new internal MK interface.
|
||||
return self.moe_mk is not None and self.moe_mk.shared_experts is not None
|
||||
return (
|
||||
self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
from .all2all_utils import maybe_make_prepare_finalize
|
||||
|
||||
return maybe_make_prepare_finalize(
|
||||
pf = maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config, routing_tables
|
||||
)
|
||||
assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
|
||||
return pf
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize"
|
||||
)
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
|
||||
raise NotImplementedError(
|
||||
"Method 'prepare_dp_allgather_tensor' is not implemented in "
|
||||
f"{self.__class__.__name__}."
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
if self.moe_mk is not None:
|
||||
return self.moe_mk.prepare_finalize.topk_indices_dtype()
|
||||
if self.moe_kernel is not None:
|
||||
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return False
|
||||
if self.moe_kernel is None:
|
||||
if hasattr(self, "experts_cls"):
|
||||
return self.experts_cls.is_monolithic()
|
||||
else:
|
||||
return False
|
||||
return self.moe_kernel.is_monolithic
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEKernel,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
# --8<-- [end:modular_fused_moe]
|
||||
|
||||
def __init__(
|
||||
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
|
||||
self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
|
||||
):
|
||||
super().__init__(old_quant_method.moe)
|
||||
self.moe_quant_config = old_quant_method.moe_quant_config
|
||||
self.moe_mk = experts
|
||||
self.moe_kernel = moe_kernel
|
||||
self.disable_expert_map = getattr(
|
||||
old_quant_method,
|
||||
"disable_expert_map",
|
||||
not self.moe_mk.supports_expert_map(),
|
||||
not self.moe_kernel.supports_expert_map(),
|
||||
)
|
||||
self.old_quant_method = old_quant_method
|
||||
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
|
||||
@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
def make(
|
||||
moe_layer: torch.nn.Module,
|
||||
old_quant_method: FusedMoEMethodBase,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
inplace: bool = False,
|
||||
) -> "FusedMoEModularMethod":
|
||||
return FusedMoEModularMethod(
|
||||
old_quant_method,
|
||||
FusedMoEModularKernel(
|
||||
FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||
shared_experts,
|
||||
@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -511,7 +511,7 @@ def make_routing_data(
|
||||
return routing_data, gather_indx, scatter_indx
|
||||
|
||||
|
||||
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
|
||||
# MoE kernel implementations.
|
||||
#
|
||||
# The following main classes are defined:
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
|
||||
# finalize method, informed by the FusedMoEExpertsModular method,
|
||||
# may apply weights and/or do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# * FusedMoEExpertsModular - an abstract base class for the main fused
|
||||
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
|
||||
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
|
||||
# Some FusedMoEExpertsModular implementations may choose to do
|
||||
# the weight application and/or reduction. The class communicates this
|
||||
# to [Finalize] via a TopKWeightAndReduce object.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
|
||||
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
|
||||
# by the FusedMoEExpertsModular implementation that is passed
|
||||
# on to [Finalize].
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
|
||||
# communication mechanisms that need to be consistent.
|
||||
#
|
||||
|
||||
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
|
||||
torch.Tensor | None,
|
||||
]
|
||||
|
||||
#
|
||||
# PrepareResultType is a tuple of:
|
||||
# - quantized + dispatched a.
|
||||
# - quantized + dispatched a1_scales.
|
||||
# - dispatched router logits.
|
||||
#
|
||||
# See `prepare_monolithic` method below.
|
||||
#
|
||||
PrepareMonolithicResultType = tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor,
|
||||
]
|
||||
|
||||
ReceiverType = Callable[[], PrepareResultType]
|
||||
|
||||
################################################################################
|
||||
# Prepare/Finalize
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above.
|
||||
|
||||
There are two variants of this class:
|
||||
* FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights
|
||||
* FusedMoEPrepareAndFinalizeMonolithic - the operates on router_logits
|
||||
"""
|
||||
|
||||
def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
|
||||
def post_init_setup(self, fused_experts: "FusedMoEExperts"):
|
||||
"""
|
||||
Initialize FusedMoEPrepareAndFinalize settings that depend on
|
||||
FusedMoEPermuteExpertsUnpermute experts object.
|
||||
The FusedMoEPrepareAndFinalize implementations that have such
|
||||
Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
|
||||
FusedMoEExpertsModular experts object.
|
||||
The FusedMoEPrepareAndFinalizeModular implementations that have such
|
||||
dependencies may choose to override this function.
|
||||
"""
|
||||
return
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the Modular case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
activations, before quantization + dispatching.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a tuple of:
|
||||
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEPermuteExpertsUnpermute
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
in cases where the compute kernel expects unquantized inputs
|
||||
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
||||
class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above for the monolithic case.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> PrepareMonolithicResultType:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsModular kernels.
|
||||
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
- defer_input_quant: Runtime parameter indicating whether or not to
|
||||
defer input quantization to the FusedMoEExpertsModular
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- Optional quantized + dispatched a1_scales.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
The PrepareFinalize All2All implementations generally constrain the
|
||||
dtype of the topk_ids they support. This function returns the
|
||||
required topk indices dtype so it can be respected.
|
||||
Return None if there are no such restrictions.
|
||||
Optional method for subclasses compatible with monolithic
|
||||
FusedMoEExpertsModular kernels.
|
||||
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output.
|
||||
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||
experts, it will have (M, topk, K) shape.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
"""
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can process only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_dispatchers(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of finalize is reduced across all
|
||||
ranks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
################################################################################
|
||||
# Experts
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: add supported activations method (return string)
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
class FusedMoEExperts(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
raise NotImplementedError("Implemented by subclasses.")
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
"""
|
||||
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
#
|
||||
# Various helpers for registering support for various features.
|
||||
# Used by the oracle to select a particular kernel for a deployment.
|
||||
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
@staticmethod
|
||||
def is_supported_config(
|
||||
cls: type["FusedMoEPermuteExpertsUnpermute"],
|
||||
cls: type["FusedMoEExperts"],
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
return False, _make_reason(
|
||||
f"parallel config {moe_config.moe_parallel_config}"
|
||||
)
|
||||
elif not cls._supports_routing_method(
|
||||
moe_config.routing_method, weight_key, activation_key
|
||||
):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif not cls._supports_router_logits_dtype(
|
||||
moe_config.router_logits_dtype,
|
||||
moe_config.routing_method,
|
||||
):
|
||||
return False, _make_reason(
|
||||
f"router logits dtype {moe_config.router_logits_dtype}"
|
||||
)
|
||||
elif not cls._supports_shape(moe_config.hidden_dim):
|
||||
return False, _make_reason(
|
||||
f"{moe_config.hidden_dim} hidden dim is not supported"
|
||||
)
|
||||
elif activation_format != cls.activation_format():
|
||||
return False, _make_reason(f"{activation_format.value} activation format")
|
||||
return True, None
|
||||
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
@abstractmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""
|
||||
Whether the kernel supports deployment in expert parallel.
|
||||
Whether the kernel supports deployment in particular parallel config.
|
||||
|
||||
Can be overriden if a kernel does not support EP, SP or some other
|
||||
configuration.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Can be overriden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain routers are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular dtype for router logits input.
|
||||
|
||||
Can be overriden by monolithic kernels that execute the router
|
||||
in addition to the experts if certain dtypes are not supported.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_shape(hidden_dim: int) -> bool:
|
||||
"""
|
||||
Whether a kernel supports a particular shape. Can be overridden if a kernel
|
||||
has specific shape requirements.
|
||||
"""
|
||||
return True
|
||||
|
||||
#
|
||||
# Various helpers for accessing quantization parameters from the
|
||||
# quant_config.
|
||||
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
|
||||
class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return False
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
Workspace type: The dtype to use for the workspace tensors.
|
||||
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
) -> None:
|
||||
apply_moe_activation(activation, output, input)
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEExpertsMonolithic(FusedMoEExperts):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above, but with the monolithic interface (accepts router logits
|
||||
rather than topk ids and weights).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a routing method (e.g. GroupedTopK).
|
||||
|
||||
Monolithic kernels should explicitly opt-in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a dtype for router logits.
|
||||
|
||||
Modular kernels should opt-in to support.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def is_monolithic() -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as apply(), except uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is useful for kernels
|
||||
with fused router and fused_experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
@@ -802,75 +1006,32 @@ def _slice_scales(
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Kernel
|
||||
################################################################################
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
a FusedMoEPermuteExpertsUnpermute to provide an interface that
|
||||
is compatible with the `fused_experts` function in fused_moe.py.
|
||||
|
||||
It takes care of managing any required scratch space.
|
||||
|
||||
Note: Instances of this class should only be used for a single model
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
|
||||
class FusedMoEKernelModularImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
fused_experts: FusedMoEExpertsModular,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.moe_parallel_config = moe_parallel_config
|
||||
self.inplace = inplace
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
# if not provided, assume this kernel is
|
||||
# running in a non-DP+EP context
|
||||
self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
|
||||
self.is_dp_ep = (
|
||||
moe_parallel_config is not None
|
||||
and moe_parallel_config.dp_size > 1
|
||||
and moe_parallel_config.use_ep
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
assert (
|
||||
prepare_finalize.activation_format == fused_experts.activation_format()
|
||||
), (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_format()}"
|
||||
)
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.fused_experts)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
|
||||
"""
|
||||
Compute number of chunks and chunk size for given M.
|
||||
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
|
||||
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
@@ -1313,13 +1474,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def forward(
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -1334,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (MoEActivation): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@@ -1354,7 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
if self.inplace:
|
||||
assert self.shared_experts is None
|
||||
assert not disable_inplace()
|
||||
@@ -1400,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernelMonolithicImpl:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
|
||||
fused_experts: FusedMoEExpertsMonolithic,
|
||||
):
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Same as forward(), except uses router_logits as opposed
|
||||
to the topk_ids and topk_weights. This is used for kernels
|
||||
that have fused router + experts (e.g. FLASHINFER_TRTLLM).
|
||||
"""
|
||||
|
||||
# TODO(rob): add inplace support.
|
||||
a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
router_logits=router_logits,
|
||||
quant_config=self.fused_experts.quant_config,
|
||||
defer_input_quant=self.fused_experts.expects_unquantized_inputs,
|
||||
)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
hidden_states=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
a1q_scale=a1q_scale,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
output = self.prepare_finalize.finalize(fused_out)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEKernel:
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEExperts,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = shared_experts # NOTE: check if we can remove
|
||||
|
||||
# Initialize the implementation (monolithic or modular).
|
||||
self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
|
||||
if isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeModular
|
||||
) and isinstance(fused_experts, FusedMoEExpertsModular):
|
||||
self.impl = FusedMoEKernelModularImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
shared_experts,
|
||||
moe_parallel_config,
|
||||
inplace,
|
||||
)
|
||||
|
||||
elif isinstance(
|
||||
prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
|
||||
) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
|
||||
assert shared_experts is None
|
||||
assert not inplace
|
||||
self.impl = FusedMoEKernelMonolithicImpl(
|
||||
prepare_finalize,
|
||||
fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"prepare_finalize and fused_experts must both be either monolithic "
|
||||
f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
|
||||
f"and {fused_experts.__class__.__name__}"
|
||||
)
|
||||
|
||||
self._post_init_setup()
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
|
||||
@property
|
||||
def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
|
||||
return self.impl.prepare_finalize
|
||||
|
||||
@property
|
||||
def fused_experts(self) -> FusedMoEExperts:
|
||||
return self.impl.fused_experts
|
||||
|
||||
def _post_init_setup(self):
|
||||
"""
|
||||
Resolve any leftover setup dependencies between self.prepare_finalize
|
||||
and self.fused_experts here.
|
||||
"""
|
||||
self.prepare_finalize.post_init_setup(self.impl.fused_experts)
|
||||
assert (
|
||||
self.prepare_finalize.activation_format
|
||||
== self.fused_experts.activation_format()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps.
|
||||
"""
|
||||
return self.fused_experts.supports_expert_map()
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
"""
|
||||
Indicates whether or not the output of fused MoE kernel
|
||||
is reduced across all ranks.
|
||||
"""
|
||||
return self.prepare_finalize.output_is_reduced()
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=router_logits,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
topk_group=topk_group,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
shared_experts_input: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(self.impl, FusedMoEKernelModularImpl)
|
||||
return self.impl.apply(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using MoRI kernels.
|
||||
"""
|
||||
|
||||
@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
fp8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
make_fp8_moe_alpha_scales_for_fi,
|
||||
prepare_fp8_moe_layer_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -103,9 +99,13 @@ def _get_priority_backends(
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: Fp8MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> type[mk.FusedMoEExperts]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
|
||||
TrtLlmFp8Experts,
|
||||
)
|
||||
|
||||
return TrtLlmFp8Experts
|
||||
|
||||
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
allow_vllm_cutlass: bool = False,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
|
||||
|
||||
if config.is_lora_enabled:
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
|
||||
|
||||
@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
|
||||
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
|
||||
)
|
||||
|
||||
# Handle FLASHINFER_TRTLLM specially (no kernel class).
|
||||
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
if fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
|
||||
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
assert fi_backend == FlashinferMoeBackend.CUTEDSL
|
||||
raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
|
||||
|
||||
raise ValueError(
|
||||
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try both.
|
||||
for backend in [
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
|
||||
block_shape: list[int] | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Create FusedMoEQuantConfig for the specified FP8 Backend.
|
||||
The FusedMoEQuantConfig holds the scales that are used
|
||||
@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
|
||||
In a future PR, we will have this function should be
|
||||
a method of the modular kernel itself.
|
||||
"""
|
||||
# TRTLLM does not use Modular Kernel abstraction yet.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
# MARLIN is mixed precision W8A16 config.
|
||||
if fp8_backend == Fp8MoeBackend.MARLIN:
|
||||
@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
|
||||
# (alpha = w_scale * a_scale) and inverse a2 scale.
|
||||
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
|
||||
assert a1_scale is not None and a2_scale is not None
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w1_scale,
|
||||
a1_scale,
|
||||
w2_scale,
|
||||
a2_scale,
|
||||
)
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
|
||||
a2_scale=a2_scale,
|
||||
a1_gscale=(1.0 / a1_scale),
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
g1_alphas=(w1_scale * a1_scale).squeeze(),
|
||||
g2_alphas=(w2_scale * a2_scale).squeeze(),
|
||||
)
|
||||
# All other backends use normal config.
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
|
||||
def make_fp8_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -605,7 +557,7 @@ def make_fp8_moe_kernel(
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
is_supported_config_trtllm,
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: NvFp4MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
) -> list[type[mk.FusedMoEExperts]]:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise NotImplementedError(
|
||||
"FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
)
|
||||
|
||||
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
|
||||
return [
|
||||
TrtLlmNvFp4ExpertsMonolithic,
|
||||
TrtLlmNvFp4ExpertsModular,
|
||||
]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
return FlashInferExperts
|
||||
return [FlashInferExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
|
||||
FlashInferCuteDSLExperts,
|
||||
)
|
||||
|
||||
return FlashInferCuteDSLExperts
|
||||
return [FlashInferCuteDSLExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
|
||||
return CutlassExpertsFp4
|
||||
return [CutlassExpertsFp4]
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
|
||||
return MarlinExperts
|
||||
return [MarlinExperts]
|
||||
else:
|
||||
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
"""
|
||||
Select the primary NvFP4 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
@@ -175,29 +181,21 @@ def select_nvfp4_moe_backend(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, k_cls
|
||||
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_nvfp4_backend(runner_backend)
|
||||
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -210,36 +208,14 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
|
||||
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend))
|
||||
return backend, None
|
||||
else:
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
else:
|
||||
backend = fi_2_vllm_backend_map[fi_backend]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
|
||||
return _return_or_raise(
|
||||
backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try each.
|
||||
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -247,13 +223,13 @@ def select_nvfp4_moe_backend(
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, None
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
|
||||
@@ -268,16 +244,7 @@ def select_nvfp4_moe_backend(
|
||||
|
||||
# Select kernels in order of backend.
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
k_cls = None # type: ignore[assignment]
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config,
|
||||
weight_key,
|
||||
activation_key,
|
||||
activation_format,
|
||||
)
|
||||
else:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
for k_cls in backend_to_kernel_cls(backend):
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls,
|
||||
config,
|
||||
@@ -286,11 +253,11 @@ def select_nvfp4_moe_backend(
|
||||
activation_format,
|
||||
)
|
||||
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
raise NotImplementedError(
|
||||
"No NvFp4 MoE backend supports the deployment configuration."
|
||||
@@ -398,12 +365,8 @@ def make_nvfp4_moe_quant_config(
|
||||
w2_scale_2: torch.Tensor,
|
||||
a13_scale: torch.Tensor,
|
||||
a2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
|
||||
if backend in UNSUPPORTED:
|
||||
return None
|
||||
|
||||
elif backend == NvFp4MoeBackend.MARLIN:
|
||||
) -> FusedMoEQuantConfig:
|
||||
if backend == NvFp4MoeBackend.MARLIN:
|
||||
return nvfp4_w4a16_moe_quant_config(
|
||||
g1_alphas=w13_scale_2,
|
||||
g2_alphas=w2_scale_2,
|
||||
@@ -420,22 +383,27 @@ def make_nvfp4_moe_quant_config(
|
||||
a2_gscale=(1.0 / a2_scale),
|
||||
w1_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
# NOTE(rob): this is a hack until the MoE kernels
|
||||
# create their own quant configs. TRTLLM kernel
|
||||
# does not accept swizzled input quant scales.
|
||||
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=moe_quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
@@ -460,7 +428,7 @@ def make_nvfp4_moe_kernel(
|
||||
# NOTE(rob): we only want the mk to control the shared_expert
|
||||
# if using all2all (for SBO). bnell is making this explicit in
|
||||
# the new MoE runner class.
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
|
||||
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_bf16,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
|
||||
backend: UnquantizedMoeBackend,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> mk.FusedMoEModularKernel | None:
|
||||
) -> mk.FusedMoEKernel | None:
|
||||
if backend in UNSUPPORTED_BACKEND:
|
||||
return None
|
||||
|
||||
@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
FlashInferExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
AiterExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -243,8 +243,8 @@ def make_unquantized_moe_kernel(
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
TritonExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
@@ -254,8 +254,8 @@ def make_unquantized_moe_kernel(
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe import XPUExperts
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
XPUExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quantization to the MoE kernel.
|
||||
use_nvfp4 = quant_config.use_nvfp4_w4a4
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
if skip_gather_scales:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
if quant_config.quant_dtype == "nvfp4":
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
"""MoE prepare and finalize without expert parallelism."""
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None, None, None, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
|
||||
MoEPrepareAndFinalizeNaiveDPEPModular,
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_naive_dp_ep,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic,
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MoEPrepareAndFinalizeNaiveDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNaiveDPEPModular",
|
||||
"make_moe_prepare_and_finalize_naive_dp_ep",
|
||||
"MoEPrepareAndFinalizeNoDPEPMonolithic",
|
||||
"MoEPrepareAndFinalizeNoDPEPModular",
|
||||
"make_moe_prepare_and_finalize_no_dp_ep",
|
||||
]
|
||||
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def _quantize_and_setup_dispatch(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
|
||||
# Defer input quantization to the MoE kernel.
|
||||
if defer_input_quant:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
else:
|
||||
input_sf = (
|
||||
quant_config.a1_gscale
|
||||
if quant_config.use_nvfp4_w4a4
|
||||
else quant_config.a1_scale
|
||||
)
|
||||
|
||||
# NOTE: swizzling pads the scales to multiple of 128
|
||||
# which makes the scales tensor different shape than
|
||||
# the hidden states, breaking the A2A kernel. So, we
|
||||
# delay the swizzling until after the A2A.
|
||||
a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=False,
|
||||
)
|
||||
|
||||
# Skip gathering scales if we have static quantization
|
||||
# (the scale is a scalar, replicated on all ranks) or
|
||||
# if quantization is deferred.
|
||||
skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
|
||||
scales = None if skip_gather_scales else [a1q_scale]
|
||||
|
||||
return a1q, scales
|
||||
|
||||
|
||||
def _unwrap_scale_and_prepare_for_moe(
|
||||
scales: list[torch.Tensor] | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> torch.Tensor:
|
||||
assert scales is not None and len(scales) == 1
|
||||
a1q_scale = scales[0]
|
||||
# Apply swizzling after a2a if the MoE kernel needs it.
|
||||
if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
|
||||
assert a1q_scale is not None
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q_scale
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the topk weights and ids.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
"""Quantize and Dispatch Topk Weights and Topk Ids."""
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch(
|
||||
a1q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, topk_weights, topk_ids = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, topk_weights, topk_ids, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
|
||||
out = weight_and_reduce_impl.apply(
|
||||
output=None,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
output.copy_(
|
||||
get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
"""
|
||||
Naive Prepare/Finalize for Dp/Ep case for Modular Kernels.
|
||||
|
||||
Uses Torch AR/RS or AR for dispatch/combine operations, applied
|
||||
to the router logits (the MoE kernel runs the router internally).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self._num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self._num_dispatchers
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
"""Quantize and Dispatch Router Logits."""
|
||||
|
||||
a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
|
||||
|
||||
res = get_ep_group().dispatch_router_logits(
|
||||
a1q,
|
||||
router_logits,
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
extra_tensors=scales,
|
||||
)
|
||||
|
||||
if scales is None:
|
||||
a1q, router_logits = res
|
||||
a1q_scale = None
|
||||
else:
|
||||
a1q, router_logits, scales = res
|
||||
a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
|
||||
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
out = get_ep_group().combine(
|
||||
fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_naive_dp_ep(
|
||||
use_monolithic: bool,
|
||||
is_sequence_parallel: bool = False,
|
||||
num_dispatchers: int = 1,
|
||||
) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNaiveDPEPMonolithic(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNaiveDPEPModular(
|
||||
is_sequence_parallel=is_sequence_parallel,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceContiguous,
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
|
||||
|
||||
def _quantize_input(
|
||||
a1: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
|
||||
# which use a single kernel call for quant + experts.
|
||||
if defer_input_quant:
|
||||
return a1, None
|
||||
|
||||
input_sf = (
|
||||
quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
|
||||
)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
input_sf,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
|
||||
)
|
||||
|
||||
return a1q, a1q_scale
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
# Note: do not use inplace for shared experts overlap
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
|
||||
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
|
||||
weight_and_reduce_impl.apply(
|
||||
output=output,
|
||||
fused_expert_output=fused_expert_output,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return None
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return 1
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return False
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareMonolithicResultType:
|
||||
a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
|
||||
return a1q, a1q_scale, router_logits
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
fused_expert_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return fused_expert_output
|
||||
|
||||
|
||||
def make_moe_prepare_and_finalize_no_dp_ep(
|
||||
use_monolithic: bool,
|
||||
) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
|
||||
return (
|
||||
MoEPrepareAndFinalizeNoDPEPMonolithic()
|
||||
if use_monolithic
|
||||
else MoEPrepareAndFinalizeNoDPEPModular()
|
||||
)
|
||||
@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
|
||||
)
|
||||
|
||||
|
||||
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
|
||||
|
||||
# TODO(bowen): When using `FusedMoEModularKernel`, this
|
||||
# can be done in a more unified way, since
|
||||
# `FusedMoEPrepareAndFinalize` will return the expert
|
||||
# `FusedMoEPrepareAndFinalizeModular` will return the expert
|
||||
# token count, in some cases directly from the kernel.
|
||||
# However, now there are many code paths not using
|
||||
# the modular kernel, e.g. calling `fused_experts`,
|
||||
|
||||
@@ -320,8 +320,8 @@ class DefaultMoERunner(MoERunner):
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.quant_method.moe_mk is not None
|
||||
and self.quant_method.moe_mk.output_is_reduced()
|
||||
self.quant_method.moe_kernel is not None
|
||||
and self.quant_method.moe_kernel.output_is_reduced()
|
||||
)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
@@ -640,45 +640,6 @@ class DefaultMoERunner(MoERunner):
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
extra_tensors = None
|
||||
if do_naive_dispatch_combine:
|
||||
post_quant_allgather = (
|
||||
self.quant_method is not None
|
||||
and self.moe_config.dp_size > 1
|
||||
and self.moe_config.use_ep
|
||||
and getattr(self.quant_method, "do_post_quant_allgather", False)
|
||||
)
|
||||
if post_quant_allgather:
|
||||
hidden_states_to_dispatch, extra_tensors = (
|
||||
self.quant_method.prepare_dp_allgather_tensor(
|
||||
layer, hidden_states, router_logits
|
||||
)
|
||||
)
|
||||
else:
|
||||
hidden_states_to_dispatch = hidden_states
|
||||
|
||||
dispatch_res = get_ep_group().dispatch_router_logits(
|
||||
hidden_states_to_dispatch,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
if extra_tensors is not None:
|
||||
(
|
||||
orig_hidden_states,
|
||||
router_logits,
|
||||
extra_tensors_combined,
|
||||
) = dispatch_res
|
||||
hidden_states_combined = (
|
||||
orig_hidden_states,
|
||||
extra_tensors_combined[0],
|
||||
)
|
||||
else:
|
||||
hidden_states_combined, router_logits = dispatch_res
|
||||
orig_hidden_states = hidden_states_combined
|
||||
else:
|
||||
orig_hidden_states = hidden_states
|
||||
|
||||
# Run shared experts before matrix multiply.
|
||||
# because matrix multiply maybe modify the hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
@@ -688,6 +649,17 @@ class DefaultMoERunner(MoERunner):
|
||||
)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
|
||||
# router logits to all experts.
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstract to better support PCP.
|
||||
@@ -701,31 +673,22 @@ class DefaultMoERunner(MoERunner):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
|
||||
# Figure out nicer way to do this.
|
||||
if do_naive_dispatch_combine:
|
||||
x = hidden_states_combined
|
||||
x_orig = orig_hidden_states
|
||||
else:
|
||||
x = hidden_states
|
||||
x_orig = hidden_states
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=x,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=x_orig,
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x, # The type signture of this is wrong due to the hack.
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
|
||||
@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
|
||||
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
|
||||
"""
|
||||
Useful in the case when some FusedMoEPermuteExpertsUnpermute
|
||||
Useful in the case when some FusedMoEExpertsModular
|
||||
implementation does not perform weight application and reduction
|
||||
but cannot address the needs of all the compatible PrepareAndFinalize
|
||||
implementations.
|
||||
@@ -62,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
|
||||
if output is None:
|
||||
return fused_expert_output
|
||||
|
||||
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
|
||||
# MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
|
||||
# tensor.
|
||||
assert output.size() == fused_expert_output.size(), (
|
||||
"output shape is expected to match the fused_expert_output shape. "
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (CutlassExpertsFp8, TritonExperts)
|
||||
|
||||
@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and hidden_states.shape[0] <= 8:
|
||||
return self.fallback_experts
|
||||
|
||||
@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
|
||||
@staticmethod
|
||||
def get_clses() -> tuple[
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEPermuteExpertsUnpermute],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
type[mk.FusedMoEExpertsModular],
|
||||
]:
|
||||
return (DeepGemmExperts, TritonExperts)
|
||||
|
||||
@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
|
||||
return self.experts
|
||||
else:
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
|
||||
"""TensorRT-LLM-based fused MoE expert implementation."""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
|
||||
UnquantizedMoeBackend,
|
||||
@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.rocm_aiter_moe_enabled = (
|
||||
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
self.kernel: mk.FusedMoEKernel | None = None
|
||||
self._is_monolithic = (
|
||||
current_platform.is_cpu()
|
||||
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
|
||||
return None
|
||||
else:
|
||||
@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
) -> FusedMoEExpertsModular:
|
||||
assert self.moe_quant_config is not None
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
@@ -325,7 +325,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel is not None
|
||||
|
||||
return self.kernel(
|
||||
return self.kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -23,7 +23,7 @@ if current_platform.is_xpu():
|
||||
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
|
||||
|
||||
|
||||
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
|
||||
@@ -19,8 +19,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
@@ -59,18 +58,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
|
||||
WNA16_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_TYPES_MAP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
flashinfer_trtllm_fp4_routed_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
|
||||
flashinfer_trtllm_mxint4_moe,
|
||||
is_flashinfer_mxint4_moe_available,
|
||||
prepare_static_weights_for_trtllm_mxint4_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
process_fp8_weight_tensor_strategy_moe,
|
||||
@@ -336,7 +328,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config is not None:
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
@@ -352,8 +344,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -562,43 +554,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_input_scale = a13_scale
|
||||
layer.w2_input_scale = a2_scale
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return make_nvfp4_moe_quant_config(
|
||||
backend=self.nvfp4_backend,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
@@ -609,13 +585,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not self.moe.moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -623,24 +592,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
assert (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not layer.enable_eplb
|
||||
)
|
||||
return flashinfer_trtllm_fp4_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -651,34 +616,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# EPLB path
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
assert layer.enable_eplb
|
||||
return flashinfer_trtllm_fp4_routed_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
@@ -966,7 +916,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -978,94 +928,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
a1_scale = layer.w13_input_scale
|
||||
a2_scale = layer.w2_input_scale
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
return make_fp8_moe_quant_config(
|
||||
fp8_backend=self.fp8_backend,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_act_token_quant=(
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN
|
||||
),
|
||||
per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=is_per_token,
|
||||
per_out_ch_quant=is_per_token,
|
||||
block_shape=self.weight_block_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1075,8 +978,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -1652,9 +1555,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
assert self.num_bits == 4, "only supporting w4"
|
||||
layer.w13_weight = layer.w13_weight_packed
|
||||
layer.w2_weight = layer.w2_weight_packed
|
||||
@@ -1943,9 +1846,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if self.moe.is_lora_enabled:
|
||||
assert self.moe_quant_config is not None
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
@@ -2527,7 +2430,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@@ -2548,9 +2451,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
assert self.moe_quant_config is not None
|
||||
assert (
|
||||
prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
|
||||
@@ -2558,7 +2461,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
|
||||
|
||||
experts: FusedMoEPermuteExpertsUnpermute
|
||||
experts: FusedMoEExpertsModular
|
||||
|
||||
logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
|
||||
experts = CutlassExpertsW4A8Fp8(
|
||||
|
||||
@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# TRTLLM does not use Modular Kernel.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
# TODO(rob): convert this to MK.
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
assert not self.is_monolithic
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
NvFp4MoeBackend,
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_nvfp4_moe_kernel,
|
||||
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
flashinfer_trtllm_fp4_routed_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
@@ -746,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -754,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -871,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w13 = layer.w13_weight
|
||||
@@ -913,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -929,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -940,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
|
||||
)
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert layer.activation in SUPPORTED_ACTIVATIONS, (
|
||||
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
|
||||
f"TRTLLM FP4 MoE, {layer.activation} found instead."
|
||||
)
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -973,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
assert layer.activation in (
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -1235,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -1420,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
|
||||
replace_parameter(layer, "w2_input_scale", a2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
@property
|
||||
def do_post_quant_allgather(self):
|
||||
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
|
||||
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise RuntimeError(
|
||||
"prepare_dp_allgather_tensor is only supported for "
|
||||
"FlashInfer TRTLLM NVFP4 MoE backend."
|
||||
)
|
||||
|
||||
import flashinfer
|
||||
|
||||
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
|
||||
hidden_states,
|
||||
layer.a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
|
||||
return hidden_states_fp4, extra_tensors
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return make_nvfp4_moe_quant_config(
|
||||
backend=self.nvfp4_backend,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
@@ -1479,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not self.moe.moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1493,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not layer.enable_eplb
|
||||
)
|
||||
|
||||
return flashinfer_trtllm_fp4_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -1520,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# EPLB path
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
assert layer.enable_eplb
|
||||
return flashinfer_trtllm_fp4_routed_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
|
||||
|
||||
@@ -266,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -440,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
MarlinExperts(
|
||||
self.moe,
|
||||
@@ -789,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
FlashInferExperts(
|
||||
moe_config=self.moe,
|
||||
@@ -954,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -1043,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
or self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
align_fp4_moe_weights_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -42,92 +32,15 @@ __all__ = [
|
||||
"reorder_w1w3_to_w3w1",
|
||||
]
|
||||
|
||||
#
|
||||
# Methods used by the oracle for kernel selection.
|
||||
#
|
||||
|
||||
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE."""
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Nvfp4 quantization."""
|
||||
SUPPORTED_W_A = [
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
]
|
||||
|
||||
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""
|
||||
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
|
||||
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
|
||||
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
|
||||
"""
|
||||
return not moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
|
||||
def is_supported_config_trtllm(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_quant_scheme(weight_key, activation_key):
|
||||
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method(moe_config.routing_method):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
elif moe_config.hidden_dim % 512 != 0:
|
||||
return False, _make_reason(
|
||||
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
|
||||
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
|
||||
return (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(100)
|
||||
)
|
||||
|
||||
|
||||
def reorder_w1w3_to_w3w1(
|
||||
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_moe(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
custom_routing_function: object | None,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer with weights and scales
|
||||
x: Input tensor
|
||||
router_logits: Router logits for expert selection
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
num_expert_group: Number of expert groups (for grouped routing)
|
||||
topk_group: Top-k within each group
|
||||
custom_routing_function: Custom routing function (e.g., Llama4)
|
||||
e_score_correction_bias: Optional routing bias correction
|
||||
|
||||
Returns:
|
||||
Output tensor from the MoE layer
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert activation in SUPPORTED_ACTIVATIONS, (
|
||||
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
|
||||
f"TRTLLM FP4 MoE, {activation} found instead."
|
||||
)
|
||||
|
||||
# Quantize input to FP4
|
||||
if isinstance(x, tuple):
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# hidden_states is the already quantized
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Determine routing method type
|
||||
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
|
||||
routing_method_type = layer.routing_method_type
|
||||
if use_llama4_routing:
|
||||
routing_method_type = flashinfer.RoutingMethodType.Llama4
|
||||
|
||||
# Cast to Fp32 (required by kernel).
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits
|
||||
)
|
||||
|
||||
# Determine activation type
|
||||
activation_type = activation_to_flashinfer_int(layer.activation)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
output2_scale_scalar=layer.g2_alphas.data,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group if num_expert_group is not None else 0,
|
||||
topk_group=topk_group if topk_group is not None else 0,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=routing_method_type,
|
||||
do_finalize=True,
|
||||
activation_type=activation_type,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_routed_moe(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
|
||||
input top k expert indices and scores rather than computing
|
||||
top k expert indices from scores.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer with weights and scales
|
||||
x: Input tensor
|
||||
topk_ids: Ids of selected experts
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
|
||||
Returns:
|
||||
Output tensor from the MoE layer
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
|
||||
assert activation == MoEActivation.SILU, (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
|
||||
# Pack top k ids and expert weights into a single int32 tensor, as
|
||||
# required by TRT-LLM
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
if isinstance(x, tuple):
|
||||
# Hidden_states is the already quantized
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# Quantize input to FP4
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
topk_ids=packed_tensor,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
output2_scale_scalar=layer.g2_alphas.data,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=0,
|
||||
topk_group=0,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
backend: "NvFp4MoeBackend",
|
||||
layer: "FusedMoE",
|
||||
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
)
|
||||
)
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
|
||||
w13,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
|
||||
|
||||
|
||||
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
return activation_to_flashinfer_type(activation).value
|
||||
|
||||
|
||||
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
|
||||
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
MoEActivation.GELU: ActivationType.Geglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation].value
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation]
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
)
|
||||
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> None:
|
||||
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
layer.output1_scales_gate_scalar = g1_alphas
|
||||
|
||||
if layer.activation.is_gated:
|
||||
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
|
||||
else:
|
||||
layer.output1_scales_scalar = (
|
||||
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
|
||||
)
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
global_num_experts: int,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
|
||||
assert (
|
||||
hasattr(layer, "output1_scales_scalar")
|
||||
and hasattr(layer, "output1_scales_gate_scalar")
|
||||
and hasattr(layer, "output2_scales_scalar")
|
||||
)
|
||||
|
||||
if layer.routing_method_type == RoutingMethodType.Llama4:
|
||||
assert (
|
||||
not layer.renormalize
|
||||
and layer.custom_routing_function == Llama4MoE.custom_routing_function
|
||||
), (
|
||||
"FusedMoE flashinfer kernels with Llama4 routing method are only "
|
||||
"supported for Llama4"
|
||||
)
|
||||
else:
|
||||
assert layer.custom_routing_function is None, (
|
||||
"Custom routing function is only supported for Llama4"
|
||||
)
|
||||
activation_type = activation_to_flashinfer_int(layer.activation)
|
||||
|
||||
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
input_scale=layer.w13_input_scale,
|
||||
gemm1_weights=layer.w13_weight,
|
||||
gemm2_weights=layer.w2_weight,
|
||||
output1_scales_scalar=layer.output1_scales_scalar,
|
||||
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
|
||||
output2_scales_scalar=layer.output2_scales_scalar,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
activation_type=activation_type,
|
||||
)
|
||||
|
||||
|
||||
def make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g1_alphas = (w13_scale * w13_input_scale).squeeze()
|
||||
g2_alphas = (w2_scale * w2_input_scale).squeeze()
|
||||
|
||||
return g1_alphas, g2_alphas
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales. Note that we do not register
|
||||
# as nn.Parameters since they are not needed for weight-reloading.
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
|
||||
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
|
||||
|
||||
@@ -172,7 +172,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
|
||||
# Further check if the ModularKernel implementation uses the DeepGemmExperts
|
||||
return isinstance(
|
||||
module.quant_method.moe_mk, (DeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -88,9 +88,14 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
|
||||
Without autotuning, FlashInfer will rely on heuristics, which may
|
||||
be significantly slower.
|
||||
"""
|
||||
from vllm.utils.flashinfer import autotune
|
||||
import vllm.utils.flashinfer as fi_utils
|
||||
|
||||
with torch.inference_mode(), fi_utils.autotune():
|
||||
# Certain FlashInfer kernels (e.g. nvfp4 routed moe) are
|
||||
# incompatible with autotuning. This state is used to skip
|
||||
# those kernels during the autotuning process.
|
||||
fi_utils._is_fi_autotuning = True
|
||||
|
||||
with torch.inference_mode(), autotune():
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
# When autotuning with number of tokens m, flashinfer will autotune
|
||||
# operations for all number of tokens up to m.
|
||||
@@ -100,3 +105,5 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None:
|
||||
skip_eplb=True,
|
||||
is_profile=True,
|
||||
)
|
||||
|
||||
fi_utils._is_fi_autotuning = False
|
||||
|
||||
@@ -140,6 +140,7 @@ autotune = _lazy_import_wrapper(
|
||||
"autotune",
|
||||
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
|
||||
)
|
||||
_is_fi_autotuning: bool = False
|
||||
|
||||
|
||||
@functools.cache
|
||||
|
||||
Reference in New Issue
Block a user