diff --git a/patches/staging_kernel.py b/patches/staging_kernel.py index 9942eb4..d409d14 100644 --- a/patches/staging_kernel.py +++ b/patches/staging_kernel.py @@ -113,7 +113,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2] k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2) - k_mask_out = (k_block_id * BLOCK_K // 2 + k_offsets_out) < (hidden_size // 2) + k_mask_out = k_offsets_out < (hidden_size // 2) tl.store( x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k, e2m1_packed,