Files
nvfp4-megamoe-kernel/src/layout.mojo
biondizzle c2b752c2fe Initial: TileLang NVFP4 mega_moe kernel package
- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
2026-05-13 15:44:51 +00:00

33 lines
1.1 KiB
Mojo

"""
NVFP4 weight transformation and scale-factor (SF) layout utilities.
Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
"""
from math import ceil_div
fn fold_global_scale_into_block_scales(
    weight_scale: Tensor[float8_e4m3fn], # (N, K//16) UE4M3 block scales
    weight_scale_2: Tensor[float32], # (num_logical,) or scalar global scale
    logical_widths: List[int], # per-logical-weight row counts
) -> Tensor[float32]:
    """Fold global scale into block scales: UE4M3 * FP32 -> FP32.

    Args:
        weight_scale: Per-block scales stored as float8_e4m3fn; the parameter
            comment gives the shape as (N, K//16).
        weight_scale_2: Global scale, either one value per logical weight
            (num_logical,) or a single scalar.
        logical_widths: Row count of each logical weight.

    Returns:
        Declared to return the folded scales as a float32 tensor.
    """
    # NOTE(review): not implemented -- the body is only the `...` placeholder,
    # so nothing is computed or returned yet.
    # Intended steps (from the original comments):
    # Convert UE4M3 to float32, multiply by global scale
    # For MergedColumnParallelLinear, expand per-logical global scale
    # NOTE(review): presumably the N rows are split into consecutive ranges of
    # logical_widths[i] rows, each scaled by weight_scale_2[i] -- confirm
    # against the deep_gemm reference before implementing.
    # NOTE(review): Tensor, float8_e4m3fn, float32 and int are not imported in
    # this file (only ceil_div is) -- confirm how they resolve; Mojo spells its
    # integer type `Int` and dtypes as `DType.<name>`.
    ...
fn pack_ue4m3_to_int32(sf: Tensor[float8_e4m3fn]) -> Tensor[int32]:
    """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA.

    Args:
        sf: Scale factors stored as float8_e4m3fn (one byte per value).

    Returns:
        Declared to return an int32 tensor holding four packed scale bytes
        per element.
    """
    # NOTE(review): not implemented -- the body is only the `...` placeholder.
    # Intended approach (from the original comment):
    # View as uint8, pack 4 consecutive bytes into int32
    # NOTE(review): the byte order inside each int32 (which value lands in
    # the low bits) and the handling of a packed dimension whose length is
    # not a multiple of 4 are unspecified here -- confirm against the layout
    # DeepGEMM expects.
    ...
fn transform_sf_into_required_layout(
    sf_mn: Tensor[int32], # MN-major packed SF
    N: int, K: int,
    recipe: Tuple[int, int], # (gran_mn, gran_k)
    num_groups: int,
) -> Tensor[int32]:
    """Transform SF into TMA-aligned UTCCP layout for DeepGEMM.

    Args:
        sf_mn: Packed scale factors in MN-major order (int32; presumably the
            output of pack_ue4m3_to_int32 -- confirm with the caller).
        N: Problem dimension N (presumably the row count of the weight --
            confirm).
        K: Problem dimension K (presumably the reduction dimension -- confirm).
        recipe: (gran_mn, gran_k) pair; presumably the scale-factor
            granularity along MN and K.
        num_groups: Number of groups (presumably one per expert -- confirm).

    Returns:
        Declared to return the re-laid-out scale factors as an int32 tensor.
    """
    # NOTE(review): not implemented -- the body is only the `...` placeholder.
    # Planned approach (from the original comment):
    # Call into DeepGEMM's C++ layout transform
    # NOTE(review): ceil_div is imported at the top of the file but no visible
    # code uses it; presumably it was meant for padding N/K up to the recipe
    # granularity -- confirm.
    ...