diff --git a/Dockerfile b/Dockerfile index 759ca288..2c47495e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,6 @@ ENV CUDA_HOME=/usr/local/cuda ENV TORCH_CUDA_ARCH_LIST="10.0" # Clone latest CUTLASS (has NVFP4 block-scaled MMA support) -ARG CUTLASS_CACHE_BUSTER=1 RUN git clone --depth 1 https://github.com/NVIDIA/cutlass.git /root/cutlass # Copy and install the NVFP4 mega_moe kernel (from this repo) @@ -34,7 +33,6 @@ RUN pip install tilelang ENV PYTHONPATH="/root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm:/root/nvfp4-megamoe-kernel:${PYTHONPATH}" # Copy vLLM patches -ARG PATCH_CACHE_BUSTER=82 COPY vllm/patches/deepseek_v4.py /tmp/patches/deepseek_v4.py COPY vllm/patches/staging_kernel.py /tmp/patches/staging_kernel.py COPY vllm/patches/deepseek_v4_attention.py /tmp/patches/deepseek_v4_attention.py diff --git a/build_and_run.sh b/build_and_run.sh new file mode 100755 index 00000000..ab4fbd71 --- /dev/null +++ b/build_and_run.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -euo pipefail + +cd "$(dirname "$0")" + +# Bust any ARG cache busters in Dockerfile by replacing with timestamp +TIMESTAMP=$(date +%s) +sed -i -E "s/ARG [A-Z_]+CACHE_BUSTER=.*/ARG CACHE_BUSTER=${TIMESTAMP}/" Dockerfile + +echo "=== Stopping existing container ===" +docker compose down 2>/dev/null || true + +echo "=== Building (no cache) ===" +docker compose build --no-cache + +echo "=== Starting ===" +docker compose up -d + +echo "=== Done. Container: $(docker compose ps -q) ===" diff --git a/scripts/serve_vllm.py b/scripts/serve_vllm.py deleted file mode 100644 index 0aed589e..00000000 --- a/scripts/serve_vllm.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -""" -DeepSeek V4 Pro NVFP4 — vLLM OpenAI-compatible server. - -Run from the venv on the B200 node: - source /root/nvidia-meeting/venv/bin/activate - python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/serve_vllm.py - -Or in the background: - nohup python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/serve_vllm.py \ - > /root/nvidia-meeting/vllm_serve.log 2>&1 & -""" - -import subprocess -import sys - -# ── Patch: Add compress_ratios to DeepseekV4Config ────────────────────────── -# transformers 5.8.0 renamed compress_ratios → compress_rates (dict format). -# vllm 0.20.2 still expects compress_ratios as a list indexed by layer_id. -# We patch the Config class to expose compress_ratios as a property that -# converts the new dict format back to the list format vllm expects. -import transformers -try: - from transformers import DeepseekV4Config - - _orig_init = DeepseekV4Config.__init__ - - def _patched_init(self, *args, **kwargs): - _orig_init(self, *args, **kwargs) - # If compress_ratios already exists as a list, leave it alone - if hasattr(self, 'compress_ratios') and isinstance(self.compress_ratios, list): - return - # Convert compress_rates dict → compress_ratios list - if hasattr(self, 'compress_rates') and isinstance(self.compress_rates, dict): - rates = self.compress_rates - # Build per-layer list from the dict schema - # V4 pattern: layers 0-1=128, then alternating 4/128, last=0 - n_layers = getattr(self, 'num_hidden_layers', 61) - cr = rates.get('compressed_sparse_attention', 4) - hr = rates.get('heavily_compressed_attention', 128) - ratios = [] - for i in range(n_layers): - if i < 2: - ratios.append(hr) - elif i == n_layers - 1: - ratios.append(0) - else: - ratios.append(cr if i % 2 == 0 else hr) - self.compress_ratios = ratios - elif hasattr(self, 'compress_rates') and isinstance(self.compress_rates, list): - self.compress_ratios = self.compress_rates - - DeepseekV4Config.__init__ = _patched_init - print("✓ Patched DeepseekV4Config.__init__ to add compress_ratios") -except ImportError: - print("⚠ DeepseekV4Config not found, skipping compress_ratios patch") - -MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" - -# These flags are critical for V4 — do not change without understanding why: -# --trust-remote-code V4 needs custom modeling code -# --kv-cache-dtype fp8 Match our kv_cache_qformat=fp8_cast quantization -# --block-size 256 V4 recommended block size -# --enable-expert-parallel Distribute expert computation across GPUs (critical for 256-expert MoE) -# --tensor-parallel-size 8 8× B200 -# --compilation-config CUDA graphs for throughput — FULL_AND_PIECEWISE + all custom ops -# --attention_config FP4 indexer cache for V4 MLA attention -# --moe-backend deep_gemm_mega_moe — optimized MoE kernel for Blackwell -# --tokenizer-mode deepseek_v4 — V4-specific tokenizer -# --tool-call-parser deepseek_v4 — native tool calling -# --enable-auto-tool-choice Auto tool choice for function calling -# --reasoning-parser deepseek_v4 — reasoning/thinking output parsing -# --speculative_config MTP speculative decoding (2 speculative tokens) - -cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "--model", MODEL, - "--trust-remote-code", - "--kv-cache-dtype", "fp8", - "--block-size", "256", - "--enable-expert-parallel", - "--tensor-parallel-size", "8", - "--compilation-config", '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}', - "--attention_config.use_fp4_indexer_cache=True", - "--moe-backend", "deep_gemm_mega_moe", # WARN: No NVFP4 mega_moe kernel. Use docker-compose (omits this flag) instead. - "--tokenizer-mode", "deepseek_v4", - "--tool-call-parser", "deepseek_v4", - "--enable-auto-tool-choice", - "--reasoning-parser", "deepseek_v4", - "--speculative_config", '{"method":"mtp","num_speculative_tokens":2}', - "--host", "0.0.0.0", - "--port", "8000", -] - -print(f"Starting vLLM server for {MODEL}") -print(f"Command: {' '.join(cmd)}") -print(f"Log: /root/nvidia-meeting/vllm_serve.log") -print() - -sys.exit(subprocess.call(cmd)) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 1d5c04ff..4e40fdaf 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -283,11 +283,8 @@ def nvfp4_mega_moe_full( # vLLM's symm_buffer stores global IDs (0..383) but our weight tensors # are indexed by local ID (0..47). Each rank handles a contiguous chunk: # rank r gets experts [r*E_per_rank, (r+1)*E_per_rank). - # We derive the start index from the first global ID that maps to local 0. num_experts_per_rank = l1_w.shape[0] - # Find experts_start_idx: the minimum global ID that this rank handles. - # All topk_ids in the buffer should fall within this rank's range. - experts_start_idx = (topk_ids.min().item() // num_experts_per_rank) * num_experts_per_rank + experts_start_idx = symm_buffer.experts_start_idx topk_ids_local = topk_ids - experts_start_idx if MEGA_MOE_DEBUG: diff --git a/src/nvfp4_megamoe_kernel/symm_buffer.py b/src/nvfp4_megamoe_kernel/symm_buffer.py index a0282dff..84be7ae6 100644 --- a/src/nvfp4_megamoe_kernel/symm_buffer.py +++ b/src/nvfp4_megamoe_kernel/symm_buffer.py @@ -31,6 +31,7 @@ class SymmBuffer: self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size + self.experts_start_idx = 0 # set by caller before kernel invocation device = torch.cuda.current_device() diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 6ea027fb..9de3ef6b 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -206,236 +206,9 @@ class DeepseekV4FP8Config(Fp8Config): return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" -import triton -import triton.language as tl -import torch - -""" -NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales. - -The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed). -This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales. -""" - - -@triton.jit -def _deepseek_v4_stage_mega_moe_inputs_kernel( - hidden_states, - x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte - x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32 - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_stride_m: tl.constexpr, - hidden_stride_k: tl.constexpr, - x_stride_m: tl.constexpr, - x_stride_k: tl.constexpr, - x_sf_stride_m: tl.constexpr, - x_sf_stride_k: tl.constexpr, - topk_ids_stride_m: tl.constexpr, - topk_ids_stride_k: tl.constexpr, - topk_weights_stride_m: tl.constexpr, - topk_weights_stride_k: tl.constexpr, - topk_idx_stride_m: tl.constexpr, - topk_idx_stride_k: tl.constexpr, - topk_weights_out_stride_m: tl.constexpr, - topk_weights_out_stride_k: tl.constexpr, - hidden_size: tl.constexpr, - top_k: tl.constexpr, - BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden) - GROUP_K: tl.constexpr, # 16 (NVFP4 group_size) - BLOCK_TOPK: tl.constexpr, -) -> None: - token_id = tl.program_id(0) - k_block_id = tl.program_id(1) - - k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - k_mask = k_offsets < hidden_size - hidden = tl.load( - hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, - mask=k_mask, - other=0.0, - ).to(tl.float32) - - num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8 - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(abs_groups, axis=1) - amax = tl.maximum(amax, 1.0e-4) - - # ---- UE4M3 scale computation ---- - # scale = amax / 6.0 (E2M1 max value = 6) - # Then quantize scale to UE4M3 format - scale = amax / 6.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = (scale_bits >> 23) & 0xFF - scale_mant = scale_bits & 0x7FFFFF - - # Convert FP32 → E4M3 manually - e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7 - e4m3_exp = tl.maximum(e4m3_exp, 0) - e4m3_exp = tl.minimum(e4m3_exp, 15) - e4m3_mant = scale_mant >> 20 - round_bit = (scale_mant >> 19) & 1 - e4m3_mant = e4m3_mant + round_bit - overflow = e4m3_mant >= 8 - e4m3_mant = tl.where(overflow, 0, e4m3_mant) - e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) - e4m3_exp = tl.minimum(e4m3_exp, 15) - scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant - - # Reconstruct dequantized scale for E2M1 quantization - e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126) - two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23 - two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True) - normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp - subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 - e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value) - - # ---- E2M1 FP4 quantization ---- - # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6] - # Nearest-neighbor using thresholds (midpoints between consecutive values) - scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None] - # Clamp to E2M1 range [-6, 6] - scaled = tl.maximum(scaled, -6.0) - scaled = tl.minimum(scaled, 6.0) - - abs_s = tl.abs(scaled) - # E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error) - # LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints - # idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] - e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) + - (abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) + - (abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) + - (abs_s >= 5.0).to(tl.int32)) - sign_bit = (scaled < 0).to(tl.int32) - e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index - - # Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble - PACKED_K: tl.constexpr = BLOCK_K // 2 # 64 - e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2]) - even, odd = tl.split(e2m1_pairs) # splits last axis (size 2) into two [PACKED_K] tensors - packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8) - - packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K) - packed_k_mask = packed_k_offsets < (hidden_size // 2) - tl.store( - x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k, - packed_byte, - mask=packed_k_mask, - ) - - # Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements) - # 8 groups per k_block of 128 → 2 int32s per k_block - # int32 can only pack 4 bytes (shifts >= 32 are UB on GPU), so split into two packs - scale_offsets = tl.arange(0, num_groups) # [0..7] - first_half = scale_offsets < 4 # groups 0-3 → int32[0] - second_half = scale_offsets >= 4 # groups 4-7 → int32[1] - - packed_lo = tl.sum( - tl.where(first_half, scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), 0), - axis=0, - ).to(tl.int32) - packed_hi = tl.sum( - tl.where(second_half, scale_e4m3_bits.to(tl.int32) << ((scale_offsets - 4) * 8), 0), - axis=0, - ).to(tl.int32) - - # Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2) - sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k - tl.store(x_sf + sf_base, packed_lo) - tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi) - - if k_block_id == 0: - topk_offsets = tl.arange(0, BLOCK_TOPK) - topk_mask = topk_offsets < top_k - - ids = tl.load( - topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, - mask=topk_mask, - other=0, - ).to(tl.int64) - tl.store( - topk_idx_out - + token_id * topk_idx_stride_m - + topk_offsets * topk_idx_stride_k, - ids, - mask=topk_mask, - ) - - weights = tl.load( - topk_weights - + token_id * topk_weights_stride_m - + topk_offsets * topk_weights_stride_k, - mask=topk_mask, - other=0.0, - ) - tl.store( - topk_weights_out - + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, - weights, - mask=topk_mask, - ) - - -def _stage_deepseek_v4_mega_moe_inputs( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - x_fp4: torch.Tensor, # uint8, shape (M, K//2) - x_sf: torch.Tensor, # int32, shape (M, K//64) - topk_idx_out: torch.Tensor, - topk_weights_out: torch.Tensor, -) -> None: - num_tokens, hidden_size = hidden_states.shape - if num_tokens == 0: - return - if hidden_size % 128 != 0: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires hidden_size to be " - "a multiple of 128." - ) - top_k = topk_ids.shape[1] - if topk_weights.shape != topk_ids.shape: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires topk_weights and " - "topk_ids to have the same shape." - ) - - block_k = 128 - grid = (num_tokens, triton.cdiv(hidden_size, block_k)) - block_topk = triton.next_power_of_2(top_k) - _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( - hidden_states, - x_fp4, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_states.stride(0), - hidden_states.stride(1), - x_fp4.stride(0), - x_fp4.stride(1), - x_sf.stride(0), - x_sf.stride(1), - topk_ids.stride(0), - topk_ids.stride(1), - topk_weights.stride(0), - topk_weights.stride(1), - topk_idx_out.stride(0), - topk_idx_out.stride(1), - topk_weights_out.stride(0), - topk_weights_out.stride(1), - hidden_size, - top_k, - BLOCK_K=block_k, - GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X) - BLOCK_TOPK=block_topk, - num_warps=4, - ) +# Staging kernel imported from standalone module (has proper RNE rounding +# and subnormal handling for UE4M3 — the old embedded copy was broken). +from staging_kernel import _stage_deepseek_v4_mega_moe_inputs def make_deepseek_v4_expert_params_mapping( @@ -757,6 +530,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): import nvfp4_megamoe_kernel as deep_gemm symm_buffer = self.get_symm_buffer() + symm_buffer.experts_start_idx = self.experts_start_idx num_tokens = hidden_states.shape[0] _stage_deepseek_v4_mega_moe_inputs( hidden_states, @@ -780,7 +554,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): sf_sample = symm_buffer.x_sf[:num_tokens] print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}") if sf_sample.numel() > 0: - print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item}") + print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item()}") topk_sample = symm_buffer.topk_idx[:num_tokens] print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}") torch.cuda.synchronize() @@ -906,7 +680,7 @@ class DeepseekV4MoE(nn.Module): raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) - # NVFP4 experts work with mega_moe via NVFP4→MXFP4 conversion in finalize_weights + # NVFP4 experts work with mega_moe via NVFP4 weight transformation in finalize_weights self.gate = GateLinear( config.hidden_size, @@ -1051,6 +825,12 @@ class DeepseekV4MoE(nn.Module): activation_clamp=activation_clamp, ) + # EP all-reduce: each rank only computes its local experts, + # so we must sum across EP ranks to get the full routed output. + torch.distributed.all_reduce( + final_hidden_states, group=self.ep_group.device_group + ) + if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) final_hidden_states += shared_output @@ -1779,7 +1559,7 @@ class DeepseekV4Model(nn.Module): - compressor.fused_wkv_wgate: Dequant NVFP4->bf16 (used via direct torch.mm in attention parallel stream) - shared_experts (gate_up_proj, down_proj): Stay native NVFP4 via DeepGEMM fp8_fp4_gemm - - MoE experts: Handled by DeepseekV4MegaMoEExperts (NVFP4→MXFP4) + - MoE experts: Handled by DeepseekV4MegaMoEExperts (NVFP4 native) """ E2M1_LUT = torch.tensor( [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16 @@ -2404,10 +2184,28 @@ class DeepseekV4ForCausalLM(nn.Module): return getattr(self.model, "_mtp_hidden_buffer", None) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) - loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + # Use the model-level loader which handles NVFP4 expert mapping, + # uint8→bf16 unpacking for MergedColumnParallelLinear, and + # bf16→NVFP4 quantization for unquantized layers. + # AutoWeightsLoader bypasses this logic and would break NVFP4 loading. + loaded_params = self.model.load_weights(weights) self.model.finalize_mega_moe_weights() self.model._convert_nvfp4_post_load() + if int(os.environ.get('NVFP4_DEBUG', '0')): + # Count loaded expert weights to catch silent load failures + for i, layer in enumerate(self.model.layers): + ffn = layer.ffn + if hasattr(ffn, 'experts') and hasattr(ffn.experts, 'w13_weight'): + w13 = ffn.experts.w13_weight + w13_sf = ffn.experts.w13_weight_scale + w13_sf2 = ffn.experts.w13_weight_scale_2 + w2 = ffn.experts.w2_weight + n_experts = w13.shape[0] + nonzero_w13 = (w13.abs().amax(dim=(1,2)) > 0).sum().item() + nonzero_w2 = (w2.abs().amax(dim=(1,2)) > 0).sum().item() + print(f"[NVFP4_DEBUG] Layer {i}: {nonzero_w13}/{n_experts} w13 nonzero, " + f"{nonzero_w2}/{n_experts} w2 nonzero, " + f"w13_sf shape={w13_sf.shape}, w13_sf2 shape={w13_sf2.shape}") if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1': torch.cuda.synchronize() print("[NVFP4] post-load conversion done, CUDA OK") diff --git a/vllm/patches/staging_kernel.py b/vllm/patches/staging_kernel.py index 3bc0fb06..679f3a97 100644 --- a/vllm/patches/staging_kernel.py +++ b/vllm/patches/staging_kernel.py @@ -106,7 +106,12 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( e4m3_mant = tl.where(overflow, 0, e4m3_mant) e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) e4m3_exp = tl.maximum(e4m3_exp, 0) - e4m3_exp = tl.minimum(e4m3_exp, 15) + # Saturation: E4M3FN reserves exp=15 for Inf/NaN (0x7F = NaN). + # Clamp to max representable finite value (exp=14, mant=7 = 0x77 = 448.0). + # This matches PyTorch's .to(torch.float8_e4m3fn) behavior. + sat = e4m3_exp >= 15 + e4m3_exp = tl.where(sat, 14, e4m3_exp) + e4m3_mant = tl.where(sat, 7, e4m3_mant) scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # Reconstruct dequantized scale by decoding the STORED E4M3 bits. @@ -187,7 +192,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, mask=topk_mask, other=0, - ).to(tl.int64) + ).to(tl.int32) tl.store( topk_idx_out + token_id * topk_idx_stride_m