[gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading (#24154)

Signed-off-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
Wei Wei
2025-09-09 21:27:53 -07:00
committed by GitHub
parent 53b42f4102
commit 0efdb5c3ba
2 changed files with 145 additions and 34 deletions

View File

@@ -122,6 +122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"MXFP4 MoE is enabled on Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results.")
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
def _should_use_marlin(self):
if envs.VLLM_MXFP4_USE_MARLIN is not None:
@@ -266,7 +267,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
elif should_use_flashinfer_mxfp4():
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
from flashinfer.fp4_quantization import (
nvfp4_block_scale_interleave)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices)
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
@@ -343,25 +347,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
gemm1_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
epilogue_tile_m))
# w13 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m))
gemm2_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
epilogue_tile_m))
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1,
1)[permute_bias_indices.to(w13_bias.device)].contiguous())
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m))
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = torch.stack(