fix staging kernel packed_k_mask double-count

This commit is contained in:
2026-05-12 08:08:24 +00:00
parent 5ea5b579c3
commit 5840291ea3

View File

@@ -113,7 +113,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
-k_mask_out = (k_block_id * BLOCK_K // 2 + k_offsets_out) < (hidden_size // 2)
+k_mask_out = k_offsets_out < (hidden_size // 2)
tl.store(
x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
e2m1_packed,