NVFP4 mega MoE: sf_id=0 fix for scale_vec::4X + UINT8 TMA + SF pipeline + interleaving

Root cause of ILLEGAL_INSTRUCTION: make_runtime_instr_desc_with_sf_id(instr_desc, k, k)
passed sf_id=1 for k=1 (second UMMA atom), but mxf4nvf4 with scale_vec::4X requires
sf_id=0 always — the hardware implicitly reads 4 SF positions per atom from a single
TMEM region. Non-zero sf_id causes the hardware to access invalid TMEM offsets.

Also includes:
- UINT8 TMA for packed FP4 (avoids 16U4 driver bugs)
- NVFP4 SF pipeline: 2 K-columns per BLOCK_K for group_size=16
- MN-major SF interleaving for gate/up L1 weights
- Fix contiguous copy for SF byte view
- Preserve MN-major layout in SF interleave
- Force contiguous on SF tensors before C++ call
- Unpack weight tuples before printing
- Single transpose back to MN-major (don't double-transpose)
This commit is contained in:
2026-05-12 20:26:13 +00:00
parent 26a8ab75a1
commit 74bf612771
5 changed files with 138 additions and 10 deletions

View File

@@ -83,8 +83,10 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDA_VERSION >= 12080
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
case kPackedFP4: // For mxf4nvf4 packed FP4: use UINT8 TMA instead of 16U4.
// The 16U4 type causes CUDA_ERROR_INVALID_VALUE on many drivers.
// UMMA descriptor handles FP4 interpretation of SMEM.
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#endif
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
@@ -123,10 +125,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
// Fix FP4 packed smem
if (not fp4_unpacked_smem and swizzle_mode != 0)
smem_inner_dim = swizzle_mode * 2;
// For packed FP4 (mxf4nvf4): use UINT8 TMA instead of 16U4_ALIGN8B.
// The 16U4 TMA type is not widely supported (causes CUDA_ERROR_INVALID_VALUE).
// We load raw bytes via UINT8 and let the UMMA descriptor interpret
// the SMEM layout as packed FP4. Dimensions stay in bytes (like UINT8).
}
CUtensorMap tensor_map;

View File

@@ -157,7 +157,7 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes (SwiGLU halving × FP4 packing), no swizzle (v1)
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,

View File

@@ -882,6 +882,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2)
// Each UTCCP call moves 128 int32s → 4 TMEM cols
// We need 2 UTCCP calls per SF: one per K-column
// NOTE: No SMEM warp transpose needed — transform_sf_token_idx
// pre-arranges the data in the correct UTCCP layout via global memory
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
#pragma unroll
@@ -906,8 +908,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Issue UMMA
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
// NVFP4 scale_vec::4X: sf_id must always be 0.
// The hardware implicitly reads 4 SF positions per UMMA atom
// from the single TMEM region [scale_A_tmem]/[scale_B_tmem].
// Unlike scale_vec::1X (MXFP4) where each atom needs a unique sf_id
// to index sub-columns, scale_vec::4X ignores sf_id or requires 0.
// Passing sf_id=k (k=1 for second UMMA atom) was the ILLEGAL_INSTRUCTION bug.
const auto runtime_instr_desc =
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, 0, 0);
a_desc.lo = mma::sm100::advance_umma_desc_lo<
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2));
b_desc.lo = mma::sm100::advance_umma_desc_lo<

View File

@@ -145,7 +145,25 @@ def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tup
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def interleave_sf_mn_major(t, gran: int = 8) -> torch.Tensor:
"""Interleave SF while preserving MN-major layout (stride(-2)=1, stride(-1)=TMA-aligned).
Input/Output shape: (num_groups, mn, packed_sf_k) with MN-major strides.
Interleaves the mn dimension: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...]
"""
# t: (groups, mn, packed_sf_k) MN-major, stride(-2)=1
# Transpose to K-major C-contiguous for safe interleave ops
t_k = t.transpose(-2, -1).contiguous() # (groups, packed_sf_k, mn) C-contiguous
g, k, mn = t_k.shape
half = mn // 2
gate = t_k[:, :, :half].reshape(g, k, half // gran, gran)
up = t_k[:, :, half:].reshape(g, k, half // gran, gran)
interleaved_k = torch.empty(g, k, mn, dtype=t.dtype, device=t.device)
interleaved_k.copy_(torch.stack([gate, up], dim=3).reshape(g, k, mn))
# Single transpose back to MN-major: (g, mn, k) with stride(-2)=1
return interleaved_k.transpose(-2, -1)
return interleave(l1_weights[0]), interleave_sf_mn_major(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
@@ -317,9 +335,12 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel)
Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.
"""
l1_w, l1_w_sf = l1_weights
l2_w, l2_w_sf = l2_weights
_C.fp8_nvfp4_mega_moe(
y,
l1_weights, l2_weights,
(l1_w, l1_w_sf), (l2_w, l2_w_sf),
cumulative_local_expert_recv_stats,
sym_buffer.buffer,
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),

97
test_nvfp4_mega_moe.py Normal file
View File

@@ -0,0 +1,97 @@
"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data."""
import torch
import torch.distributed as dist
import os
def test_nvfp4_mega_moe():
# Small dimensions that satisfy alignment requirements
# hidden and intermediate_hidden must be multiples of 128
# hidden must be divisible by 64 (for NVFP4 SF packing)
num_experts = 2
num_tokens = 4
top_k = 2
hidden = 256 # must be multiple of 128 and 64
intermediate_hidden = 512 # must be multiple of 128 and 64
device = "cuda"
torch.cuda.set_device(0)
# Create a single-rank process group for SymmBuffer
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
if not dist.is_initialized():
dist.init_process_group("nccl")
group = dist.new_group()
from deep_gemm.mega import (
fp8_nvfp4_mega_moe,
get_symm_buffer_for_nvfp4_mega_moe,
transform_nvfp4_weights_for_mega_moe,
)
# Create random NVFP4 weights (E2M1 packed int8 + float8_e4m3fn block scales)
# w13: (num_experts, 2*intermediate_hidden, hidden//2)
w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2),
dtype=torch.uint8, device=device).view(torch.int8)
w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16,
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
w13_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
w13_input_scale = torch.ones(num_experts, device=device)
# w2: (num_experts, hidden, intermediate_hidden//2)
w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2),
dtype=torch.uint8, device=device).view(torch.int8)
w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16,
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
w2_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
w2_input_scale = torch.ones(num_experts, device=device)
# Transform weights for the kernel
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
(w13_weight, w13_weight_scale),
(w2_weight, w2_weight_scale),
l1_weight_scale_2=w13_weight_scale_2,
l2_weight_scale_2=w2_weight_scale_2,
)
print(f"l1_weights: dtype={l1_weights[0].dtype} shape={l1_weights[0].shape} strides={l1_weights[0].stride()}")
print(f"l1_sf: dtype={l1_weights[1].dtype} shape={l1_weights[1].shape} strides={l1_weights[1].stride()}")
print(f"l2_weights: dtype={l2_weights[0].dtype} shape={l2_weights[0].shape} strides={l2_weights[0].stride()}")
print(f"l2_sf: dtype={l2_weights[1].dtype} shape={l2_weights[1].shape} strides={l2_weights[1].stride()}")
# Create symm buffer
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)
# Create input (BF16)
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
# Create topk weights/ids
topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)
# Stage inputs
from deepseek_v4_staging import _stage_deepseek_v4_mega_moe_inputs
# Actually, we can't import from vllm patch. Let's just manually set up the symm buffer.
# Output tensor
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
# Call the kernel
print("Calling fp8_nvfp4_mega_moe...")
try:
fp8_nvfp4_mega_moe(
y,
l1_weights, l2_weights,
symm_buffer,
)
print("SUCCESS! y stats: min={:.4f} max={:.4f} mean={:.4f}".format(
y.min().item(), y.max().item(), y.mean().item()))
except Exception as e:
print(f"FAILED: {e}")
raise
if __name__ == "__main__":
test_nvfp4_mega_moe()