Fix fallback to default tactic (flashinfer autotuner) with trtllm_fp4_block_scale_moe (#35088)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -348,7 +348,7 @@ def flashinfer_trtllm_fp4_moe(
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).flatten(),
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
@@ -432,7 +432,7 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).flatten(),
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
|
||||
Reference in New Issue
Block a user