fix: replace nested tl.where with sum-of-comparisons for E2M1 quantization
Triton can't compile deeply nested tl.where. Use arithmetic instead: idx = sum(abs_s >= threshold_i) for 7 threshold values.
This commit is contained in:
@@ -340,15 +340,13 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
scaled = tl.minimum(scaled, 6.0)
|
||||
|
||||
abs_s = tl.abs(scaled)
|
||||
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
|
||||
e2m1_idx = tl.where(abs_s < 0.25, 0,
|
||||
tl.where(abs_s < 0.75, 1,
|
||||
tl.where(abs_s < 1.25, 2,
|
||||
tl.where(abs_s < 1.75, 3,
|
||||
tl.where(abs_s < 2.5, 4,
|
||||
tl.where(abs_s < 3.5, 5,
|
||||
tl.where(abs_s < 5.0, 6, 7)))))))
|
||||
# E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error)
|
||||
# LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints
|
||||
# idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
|
||||
e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) +
|
||||
(abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) +
|
||||
(abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) +
|
||||
(abs_s >= 5.0).to(tl.int32))
|
||||
sign_bit = (scaled < 0).to(tl.int32)
|
||||
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
|
||||
|
||||
|
||||
Reference in New Issue
Block a user