more fixes7
This commit is contained in:
@@ -150,11 +150,10 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
|
||||
# Pack E2M1 pairs into single bytes (2 per byte, low nibble first)
|
||||
# mxf4nvf4 reads FP4 packed from SMEM — must match kernel's TMA layout
|
||||
# Reshape to pairs instead of strided indexing (Triton doesn't support
|
||||
# [0::2] on reshaped tensors — unsupported tensor index error)
|
||||
e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
|
||||
e2m1_lo = e2m1_pairs[:, 0] # even indices → low nibble
|
||||
e2m1_hi = e2m1_pairs[:, 1] # odd indices → high nibble
|
||||
# e2m1_4bit is [num_groups, GROUP_K] — stride within each group (row-major
|
||||
# layout means within-group striding pairs the same elements as flat striding)
|
||||
e2m1_lo = e2m1_4bit[:, 0::2] # even within group → low nibble
|
||||
e2m1_hi = e2m1_4bit[:, 1::2] # odd within group → high nibble
|
||||
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
|
||||
|
||||
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
|
||||
|
||||
Reference in New Issue
Block a user