[Bugfix] Fix passing of activation_type to trtllm fused MoE NVFP4 and FP8 (#36017)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
(cherry picked from commit d7adcadb9b)
This commit is contained in:
@@ -240,12 +240,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Delay import for non-CUDA.
|
# Delay import for non-CUDA.
|
||||||
import flashinfer
|
import flashinfer
|
||||||
from flashinfer.fused_moe.core import ActivationType
|
|
||||||
|
|
||||||
# Confirm supported activation function.
|
# Confirm supported activation function.
|
||||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||||
|
|
||||||
activation_type = ActivationType(activation_to_flashinfer_int(activation))
|
activation_type = activation_to_flashinfer_int(activation)
|
||||||
|
|
||||||
# Confirm Llama-4 routing is proper.
|
# Confirm Llama-4 routing is proper.
|
||||||
if self.routing_method_type == RoutingMethodType.Llama4:
|
if self.routing_method_type == RoutingMethodType.Llama4:
|
||||||
|
|||||||
@@ -323,4 +323,5 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
routing_method_type=self.routing_method_type,
|
routing_method_type=self.routing_method_type,
|
||||||
do_finalize=True,
|
do_finalize=True,
|
||||||
|
activation_type=activation_to_flashinfer_int(activation),
|
||||||
)[0]
|
)[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user