[NVIDIA] [feat] Integrate flashinfer Trtllmgen bf16 moe (#32954)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
@@ -195,6 +195,81 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
    return backend in backends_supporting_global_sf


def convert_moe_weights_to_flashinfer_trtllm_block_layout(
    cache_permute_indices: dict[torch.Size, torch.Tensor],
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert expert weights to FlashInfer's block layout.

    This reorders W13 and W2 into the expected epilogue-tiled block layout and
    returns the shuffled weight tensors.
    """
    if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16:
        raise ValueError(
            "Unquantized MoE backend FlashInfer TRTLLM requires bfloat16 weights"
        )

    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        convert_to_block_layout,
        get_w2_permute_indices_with_cache,
    )

    epilogue_tile_m = 128
    block_k = 128

    # Reorder rows of W13 and W2 for the fused gated activation and convert to
    # the block layout expected by the FlashInfer kernel.
    num_experts = w13_weight.shape[0]
    device_w13 = w13_weight.device
    device_w2 = w2_weight.device

    w13_weights_shuffled: list[torch.Tensor] = []
    w2_weights_shuffled: list[torch.Tensor] = []

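    # Note: the loop below views the bf16 weights as uint8 so the cached row
    # permutation indices operate on raw bytes; the permuted, block-tiled data
    # is viewed back as bf16 before stacking.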
    for i in range(num_experts):
        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
            cache_permute_indices,
            w13_weight[i].view(torch.uint8),
            epilogue_tile_m,
        )
        tmp_weights1 = (
            w13_weight[i]
            .clone()
            .view(torch.uint8)[permute_indices.to(device_w13)]
            .contiguous()
        )

        permute_indices = get_w2_permute_indices_with_cache(
            cache_permute_indices,
            w2_weight[i].view(torch.uint8),
            epilogue_tile_m,
        )
        tmp_weights2 = (
            w2_weight[i]
            .clone()
            .view(torch.uint8)[permute_indices.to(device_w2)]
            .contiguous()
        )

        tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k)
        tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k)

        w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16))
        w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16))

    # Stack weights for all experts and return as BF16 tensors.
    w13_weights_shuffled_tensor = (
        torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous()
    )
    w2_weights_shuffled_tensor = (
        torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous()
    )

    return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor


def align_fp8_moe_weights_for_fi(
    w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
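For context, a minimal usage sketch of the new helper (not part of the diff). It assumes the function is imported from the module this hunk modifies, that FlashInfer and a CUDA device are available, and it uses illustrative expert/hidden/intermediate sizes with the usual fused gate/up layout for W13; all names and sizes below are assumptions.

import torch

# Hypothetical sizes; real values come from the model config.
num_experts, hidden_size, intermediate_size = 8, 1024, 4096

w13 = torch.randn(
    num_experts,
    2 * intermediate_size,
    hidden_size,
    dtype=torch.bfloat16,
    device="cuda",
)
w2 = torch.randn(
    num_experts,
    hidden_size,
    intermediate_size,
    dtype=torch.bfloat16,
    device="cuda",
)

# Cache keyed by weight shape (per the type hint above), so permute indices
# are computed once per unique shape and reused across experts.
cache_permute_indices: dict[torch.Size, torch.Tensor] = {}

w13_blocked, w2_blocked = convert_moe_weights_to_flashinfer_trtllm_block_layout(
    cache_permute_indices, w13, w2
)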