From 27dbf2850f459a57396d9002ba0b5e0aae15f3da Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 22:23:05 +0000 Subject: [PATCH] fix: replace nested tl.where with sum-of-comparisons for E2M1 quantization Triton can't compile deeply nested tl.where. Use arithmetic instead: idx = sum(abs_s >= threshold_i) over the 7 midpoint thresholds. For non-NaN inputs this yields the same index as the previous `<` chain, including values exactly on a threshold. --- patches/deepseek_v4.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 50b8303..5958d08 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -340,15 +340,13 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( scaled = tl.minimum(scaled, 6.0) abs_s = tl.abs(scaled) - # Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6] - # [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF] - e2m1_idx = tl.where(abs_s < 0.25, 0, - tl.where(abs_s < 0.75, 1, - tl.where(abs_s < 1.25, 2, - tl.where(abs_s < 1.75, 3, - tl.where(abs_s < 2.5, 4, - tl.where(abs_s < 3.5, 5, - tl.where(abs_s < 5.0, 6, 7))))))) + # E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error) + # LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints + # idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] + e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) + + (abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) + + (abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) + + (abs_s >= 5.0).to(tl.int32)) sign_bit = (scaled < 0).to(tl.int32) e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index