fix: replace nested tl.where with sum-of-comparisons for E2M1 quantization

Triton can't compile deeply nested tl.where. Use arithmetic instead:
idx = sum(abs_s >= threshold_i) for 7 threshold values.
This commit is contained in:
2026-05-11 22:23:05 +00:00
parent 3d1f3de190
commit 27dbf2850f

View File

@@ -340,15 +340,13 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
scaled = tl.minimum(scaled, 6.0)
abs_s = tl.abs(scaled)
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
e2m1_idx = tl.where(abs_s < 0.25, 0,
tl.where(abs_s < 0.75, 1,
tl.where(abs_s < 1.25, 2,
tl.where(abs_s < 1.75, 3,
tl.where(abs_s < 2.5, 4,
tl.where(abs_s < 3.5, 5,
tl.where(abs_s < 5.0, 6, 7)))))))
# E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error)
# LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints
# idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) +
(abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) +
(abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) +
(abs_s >= 5.0).to(tl.int32))
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index