fix staging kernel packed_k_mask double-count
This commit is contained in:
@@ -113,7 +113,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
|
||||
|
||||
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
|
||||
k_mask_out = (k_block_id * BLOCK_K // 2 + k_offsets_out) < (hidden_size // 2)
|
||||
k_mask_out = k_offsets_out < (hidden_size // 2)
|
||||
tl.store(
|
||||
x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
|
||||
e2m1_packed,
|
||||
|
||||
Reference in New Issue
Block a user