[Bugfix][MXFP4] Call trtllm_fp4_block_scale_moe with kwargs (#33104)
Signed-off-by: Pengchao Wang <wpc@fb.com>
This commit is contained in:
@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
None, # routing_bias
|
||||
x_quant,
|
||||
x_scale,
|
||||
layer.w13_weight, # uint8 (e2m1 x 2)
|
||||
layer.w13_weight_scale, # uint8 (e4m3 x 2)
|
||||
layer.w13_bias, # fp32 per expert per channel
|
||||
layer.gemm1_alpha, # fp32 per expert
|
||||
layer.gemm1_beta, # fp32 per expert
|
||||
layer.gemm1_clamp_limit, # fp32 per expert
|
||||
layer.w2_weight, # uint8 (e2m1 x 2)
|
||||
layer.w2_weight_scale, # ue8m0
|
||||
layer.w2_bias, # fp32 per expert per channel
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
layer.global_num_experts,
|
||||
layer.top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
layer.ep_rank * layer.local_num_experts, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
None, # routed_scaling_factor
|
||||
1 if layer.renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=x_quant,
|
||||
hidden_states_scale=x_scale,
|
||||
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
|
||||
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
|
||||
gemm1_bias=layer.w13_bias, # fp32 per expert per channel
|
||||
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
|
||||
gemm1_beta=layer.gemm1_beta, # fp32 per expert
|
||||
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
|
||||
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
|
||||
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
|
||||
gemm2_bias=layer.w2_bias, # fp32 per expert per channel
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=self.intermediate_size, # padded to multiple of 256
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=self.num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1 if layer.renormalize else 0,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
|
||||
Reference in New Issue
Block a user