[NVIDIA] [feat] Integrate flashinfer Trtllmgen bf16 moe (#32954)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
This commit is contained in:
Linda
2026-01-29 19:00:13 +01:00
committed by GitHub
parent 8c8ebeb941
commit 0493d897c4
5 changed files with 290 additions and 17 deletions

View File

@@ -195,6 +195,81 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
return backend in backends_supporting_global_sf
def convert_moe_weights_to_flashinfer_trtllm_block_layout(
cache_permute_indices: dict[torch.Size, torch.Tensor],
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert expert weights to FlashInfer's block layout.
This reorders W13 and W2 into the expected epilogue-tiled block layout and
returns the shuffled weight tensors.
"""
if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16:
raise ValueError(
"Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights"
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w3_w1_permute_indices,
convert_to_block_layout,
get_w2_permute_indices_with_cache,
)
epilogue_tile_m = 128
block_k = 128
# Reorder rows of W13 and W2 for fused gated activation and convert to the
# block layout expected by the FlashInfer kernel.
num_experts = w13_weight.shape[0]
device_w13 = w13_weight.device
device_w2 = w2_weight.device
w13_weights_shuffled: list[torch.Tensor] = []
w2_weights_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights1 = (
w13_weight[i]
.clone()
.view(torch.uint8)[permute_indices.to(device_w13)]
.contiguous()
)
permute_indices = get_w2_permute_indices_with_cache(
cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights2 = (
w2_weight[i]
.clone()
.view(torch.uint8)[permute_indices.to(device_w2)]
.contiguous()
)
tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k)
tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k)
w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16))
w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16))
# Stack weights for all experts and return as BF16 tensors.
w13_weights_shuffled_tensor = (
torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous()
)
w2_weights_shuffled_tensor = (
torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous()
)
return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]: