- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
33 lines
1.1 KiB
Mojo
33 lines
1.1 KiB
Mojo
"""
NVFP4 weight transformation and scale-factor (SF) layout utilities.

Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
"""
|
|
|
|
from math import ceil_div
|
|
|
|
fn fold_global_scale_into_block_scales(
    weight_scale: Tensor[float8_e4m3fn],
    weight_scale_2: Tensor[float32],
    logical_widths: List[int],
) -> Tensor[float32]:
    """Fold the global scale into the block scales (UE4M3 * FP32 -> FP32).

    Placeholder: the body is not implemented here; the notes below record
    the intended approach.

    Args:
        weight_scale: UE4M3 block scales, shape (N, K//16).
        weight_scale_2: Global scale, shape (num_logical,) or a scalar.
        logical_widths: Row count of each logical weight.

    Returns:
        The block scales with the global scale folded in, as FP32.
    """
    # Intended approach:
    #   - Convert the UE4M3 block scales to float32 and multiply them by the
    #     global scale.
    #   - For MergedColumnParallelLinear, expand the per-logical global scale.
    ...
|
|
|
|
fn pack_ue4m3_to_int32(
    sf: Tensor[float8_e4m3fn],
) -> Tensor[int32]:
    """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA.

    Placeholder: the body is not implemented here; the note below records
    the intended approach.

    Args:
        sf: UE4M3 values to pack.

    Returns:
        The packed values, one int32 per 4 consecutive input bytes.
    """
    # Intended approach: view the input as uint8 and pack every 4 consecutive
    # bytes into a single int32.
    ...
|
|
|
|
fn transform_sf_into_required_layout(
    sf_mn: Tensor[int32],
    N: int,
    K: int,
    recipe: Tuple[int, int],
    num_groups: int,
) -> Tensor[int32]:
    """Transform SF into the TMA-aligned UTCCP layout for DeepGEMM.

    Placeholder: the body is not implemented here; the note below records
    the intended approach.

    Args:
        sf_mn: MN-major packed SF.
        N: Size of the N dimension (NOTE(review): exact meaning is not
            shown in this file; confirm against the caller).
        K: Size of the K dimension (NOTE(review): confirm against the
            caller).
        recipe: Granularity pair (gran_mn, gran_k).
        num_groups: Number of groups (NOTE(review): presumably the grouped
            GEMM / expert count; confirm against the caller).

    Returns:
        The SF tensor in the layout DeepGEMM requires.
    """
    # Intended approach: delegate to DeepGEMM's C++ layout transform.
    ...
|