more fixes7

This commit is contained in:
2026-05-14 20:11:37 +00:00
parent 4363eee2ce
commit 6aae8f1393

View File

@@ -150,11 +150,10 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
# Pack E2M1 pairs into single bytes (2 per byte, low nibble first)
# mxf4nvf4 reads FP4 packed from SMEM — must match kernel's TMA layout
# Reshape to pairs instead of strided indexing (Triton doesn't support
# [0::2] on reshaped tensors — unsupported tensor index error)
e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
e2m1_lo = e2m1_pairs[:, 0] # even indices → low nibble
e2m1_hi = e2m1_pairs[:, 1] # odd indices → high nibble
# e2m1_4bit is [num_groups, GROUP_K] — stride within each group (row-major
# layout means within-group striding pairs the same elements as flat striding)
e2m1_lo = e2m1_4bit[:, 0::2] # even within group → low nibble
e2m1_hi = e2m1_4bit[:, 1::2] # odd within group → high nibble
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)