WIP: CuTeDSL shared expert kernel
Dedicated runner (shared_expert_pipeline.py) and test (test_shared_expert.py). Tried reusing MoE runner with 1 expert — fails because MoE runner assumes hidden_size != HC_DIM for scatter. Need dedicated runner with correct scale assembly. Will continue tomorrow.
159
CURRENT_BUG.md
@@ -1,68 +1,113 @@
 # Current Bug: vLLM produces empty/garbage output

-**Status:** Weights confirmed good — bug is in vLLM's quant pipeline for attention
+**Status:** Debugging, plan revised — building our own kernels
 **Date:** 2026-05-18

 ## Symptom
 - vLLM server starts, loads model, processes requests (200 OK)
 - Chat completions return `content: ""` with `finish_reason: "length"`
 - 20 completion tokens generated but all produce empty/NaN logits
 - With enforce-eager + diagnostics: **NaN from layer 0 onward** on real requests

-## ✅ Confirmed: Weights produce valid output
+## What we know

-Standalone test (`test_attn_moe_chain.py`) running directly on B200:
+### ✅ Confirmed working
+- **MoE expert CuTeDSL kernel** — cosine 0.988, cudagraph-safe, production-ready
+- **All NVFP4 weights dequantize correctly to BF16** — standalone test proves it
+- **Full attention weight chain produces valid output** (embed → q_a → norm → q_b → o_a → o_b)
+- **Post-quant fix runs at the right time** — patched `utils.py` calls `_post_quant_fix()` after `process_weights_after_loading`
+- **183 attention projections dequantized to BF16** (61 layers × 3 projs)

-| Step | Operation | amax | NaN? |
-|------|-----------|------|------|
-| 1 | Embed tokens | 1.27 | No |
-| 2 | hc_mult expansion | 1.27 | No |
-| 3 | RMSNorm | 0.20 | No |
-| 4 | q_a_proj (NVFP4→BF16 dequant + matmul) | 0.50 | No |
-| 5 | kv_proj (NVFP4→BF16 dequant + matmul) | 1.30 | No |
-| 6 | q_norm + kv_norm | 0.11 / 1.87 | No |
-| 7 | q_b_proj (NVFP4→BF16 dequant + matmul) | 1.10 | No |
-| 8 | MoE CuTeDSL runner (with warmup gs) | cosine 0.988 | No |
+### ❌ Still broken
+- Even with BF16 attention, model produces empty output
+- Shared experts also use `FlashInferCutlassNvFp4LinearKernel` with broken `input_scale`
+- Added shared experts to BF16 dequant fix (122 more projections) — **testing in progress**

-**Every step produces valid, non-NaN, non-zero output.** The problem is NOT the weights.
+### 🔥 The real problem: vLLM's NVFP4 kernels are untrustworthy on B200

-## ❌ Root Cause: vLLM's `process_weights_after_loading` breaks attention
+We spent the entire day fighting vLLM's `FlashInferCutlassNvFp4LinearKernel`:
+- Broken `input_scale` → NaN
+- `process_weights_after_loading` timing issues
+- Forward hooks not firing due to torch.compile/model wrappers
+- Dequant-to-BF16 workaround is a bandaid that loses NVFP4 benefits

-### The timeline
+**We could have built our own kernel in the time we spent debugging theirs.**

-1. `load_weights()` → our `_convert_nvfp4_post_load()` runs
-2. `process_weights_after_loading()` → vLLM's quant method runs AFTER, **overwriting our fixes**
-3. `FlashInferCutlassNvFp4LinearKernel` gets set up with broken `input_global_scale_inv`
+## Revised Plan: Our Own NVFP4 Kernels

-### What the quant method does
+**Goal:** Replace ALL vLLM NVFP4 kernel paths with our own CuTeDSL implementations. No more `FlashInferCutlassNvFp4LinearKernel`. No more BF16 dequant workarounds.

-`CompressedTensorsW4A4Fp4.process_weights_after_loading()`:
-```python
-input_global_scale_inv = layer.input_scale.max() # = 0.00025141 (WRONG)
-layer.input_global_scale = 1.0 / input_global_scale_inv # = 3977.6
-layer.input_global_scale_inv = input_global_scale_inv # = 0.00025141
-layer.alpha = input_global_scale * weight_global_scale
-```
+### Phase 0: Get the BF16 fix working (current)
+- Post-quant BF16 dequant for attention + shared experts
+- Verify the model produces actual text output
+- This is the "make it work" step

-At runtime: `scaled_fp4_quant(x, input_global_scale_inv=0.00025141)` divides by 0.00025141 → multiplies by 3977.6 → massive overflow → NaN.
+### Phase 1: CuTeDSL Shared Expert Kernel
+**Priority:** High — shared experts are the last NVFP4 component using vLLM's broken kernel

-### Why our fixes didn't work
+**Files to create:**
+- `cutedsl/shared_expert_pipeline.py` — L1 GEMM → SiLU → re-quant → L2 GEMM
+  - Same pattern as MoE but simpler: no routing, no topk, no scatter
+  - `gate_up_proj` already stacked (same as MoE L1)
+  - `down_proj` same as MoE L2
+- `vllm/nvfp4_shared_expert.py` — runner class
+  - Cudagraph-safe (pre-allocated buffers)
+  - Warmup-based gs computation (same as MoE)
+  - Called from `DeepseekV4MoE.forward()` for shared expert path
+- `tests/test_shared_expert.py` — standalone test
+  - Load shared expert weights from checkpoint
+  - CuTeDSL vs BF16 reference (cosine)
+  - Cudagraph test

-| Attempt | Why it failed |
-|---------|---------------|
-| BF16 dequant + `UnquantizedLinearMethod` | `process_weights_after_loading` overwrites `quant_method` back to `FlashInferCutlassNvFp4LinearKernel` |
-| Fix `input_scale` before quant method | Runs too early — quant method reads `input_scale` and overwrites our value |
-| Fix `input_global_scale_inv` directly | Attribute doesn't exist yet when our code runs — it's set BY the quant method |
+**Why it's easy:** Shared experts are literally MoE with 1 expert and no routing. The CuTeDSL `ScaledGroupedGemmKernel` with `num_groups=1` is just a regular GEMM.

-### The key insight
+### Phase 2: CuTeDSL Attention Kernel
+**Priority:** High — attention is the biggest remaining NVFP4 component

-Our code runs **inside** `load_weights()`. The quant method's `process_weights_after_loading()` runs **after** `load_weights()` returns. Any changes we make get overwritten.
+**Components to handle:**
+- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
+- `wq_b` — ColumnParallelLinear (second Q projection)
+- `wo_a` — currently FP8 via fp8_einsum
+- `wo_b` — ColumnParallelLinear (output projection)

-## Config values (corrected)
+**Design options:**
+1. **Separate GEMMs** — one CuTeDSL GEMM per projection, simplest
+2. **Fused attention GEMM** — batch all projections together (more complex, more speed)
+
+**Recommended: Start with option 1.** Each projection is just a standard NVFP4 GEMM. No need to fuse. We can optimize later.
+
+**Files to create:**
+- `cutedsl/attention_pipeline.py` — NVFP4 GEMMs for each attention projection
+- `vllm/nvfp4_attention.py` — runner class
+  - Handles q_a_proj, kv_proj, q_b_proj, o_a_proj, o_b_proj
+  - Cudagraph-safe
+  - Warmup gs for each projection
+- `tests/test_attention_nvfp4.py` — standalone test
+
+**Challenge:** `fused_wqa_wkv` has TWO weight_scale_2 values (one for q_a, one for kv). Need to handle dual global scales (same pattern as MoE gate+up with different gs).
+
+### Phase 3: Clean up
+- Remove all BF16 dequant code
+- Remove `vllm/patches/utils.py` patch
+- Remove `_post_quant_fix()` method
+- All NVFP4 goes through our CuTeDSL kernels
+- BF16 only where it must be (SiLU activation, final scatter, embeddings)
+
+## NVFP4 Kernel Coverage (Target)
+
+| Component | Kernel | Status |
+|-----------|--------|--------|
+| MoE experts (L1+L2) | CuTeDSL ScaledGroupedGemm | ✅ Working |
+| Shared experts (L1+L2) | CuTeDSL standard GEMM | 🔧 Phase 1 |
+| Attention projections | CuTeDSL standard GEMM | 🔧 Phase 2 |
+| wo_a | CuTeDSL or keep FP8 | 🔧 Phase 2 |
+| Compressor | BF16 (small, not worth it) | ✅ Done |
+| KV cache | FP8 (vLLM, not our concern) | ✅ Works |
+
+## Config values

 | Parameter | Value |
 |-----------|-------|
-| head_dim | 512 (NOT 56) |
+| head_dim | 512 |
 | num_attention_heads | 128 |
 | num_key_value_heads | 1 |
 | q_lora_rank | 1536 |
@@ -70,35 +115,7 @@ Our code runs **inside** `load_weights()`. The quant method's `process_weights_a
 | o_lora_rank | 1024 |
 | hc_mult | 4 |
 | n_routed_experts | 384 (48 per EP rank) |
-
-## Next step: Post-init hook
-
-The fix must run AFTER `process_weights_after_loading` and BEFORE the first inference. Options:
-
-**Option A: Override `input_global_scale_inv` post-init**
-- Add a `_fix_nvfp4_activation_scales()` method
-- Call it from the right hook point (after quant method setup, before inference)
-- Compute correct `input_global_scale_inv` from BF16 warmup
-- Override the Parameter on each attention module
-
-**Option B: Replace quant_method with UnquantizedLinearMethod post-init**
-- After `process_weights_after_loading`, dequant weights to BF16
-- Swap `quant_method` on attention modules to `UnquantizedLinearMethod`
-- This time the quant method won't overwrite us (it already ran)
-
-**Option C: Override the quant config to skip attention modules**
-- Tell `CompressedTensorsW4A4Fp4` to skip attention projections
-- Then dequantize to BF16 ourselves
-- Cleanest but requires modifying the quant config
-
-Option B is most straightforward. The quant method already ran and set up its attributes. We can then come in and replace everything with BF16.
-
-## Architecture notes
-
-- Attention uses MLA (Multi-head Latent Attention) with 2-step Q projection (q_a → q_b)
-- `fused_wqa_wkv` = MergedColumnParallelLinear(q_a + kv fused)
-- `wo_a` = FP8 via fp8_einsum (no input_scale, weight-only)
-- `wo_b` = standard ColumnParallelLinear
-- `hc_pre` / `hc_post` = Head-Conditioned mixing (tilelang custom ops)
-- Dummy run zeros attention output by design (`out.zero_(); return`)
-- FlashMLA handles the actual MLA attention kernel
+| shared expert gate_proj | [3072, 3584] = 11MB NVFP4 / 44MB BF16 |
+| shared expert up_proj | [3072, 3584] = 11MB NVFP4 / 44MB BF16 |
+| shared expert down_proj | [7168, 1536] = 11MB NVFP4 / 44MB BF16 |
+| shared expert total | 33MB NVFP4 / 132MB BF16 per layer, ~2GB / ~8GB total |
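Review note on the timing problem described in CURRENT_BUG.md: the ordering issue can be reproduced with a toy model. The sketch below is only an illustration under stated assumptions. `FakeLayer`, the string-valued `quant_method`, and the three functions are hypothetical stand-ins, not vLLM's actual classes or call signatures. It shows why a fix applied during loading is lost, while one applied after the post-load step survives.

```python
class FakeLayer:
    """Hypothetical stand-in for a quantized linear layer."""

    def __init__(self):
        self.quant_method = "nvfp4"


def load_weights(layer):
    # A fix applied here runs too early: the post-load step below undoes it.
    layer.quant_method = "bf16"


def process_weights_after_loading(layer):
    # Stand-in for the framework step that runs after load_weights() returns
    # and resets the layer to its quantized kernel.
    layer.quant_method = "nvfp4"


def post_quant_fix(layer):
    # A fix applied after the post-load step is the last writer, so it sticks.
    layer.quant_method = "bf16"


def build(apply_post_fix):
    layer = FakeLayer()
    load_weights(layer)
    process_weights_after_loading(layer)
    if apply_post_fix:
        post_quant_fix(layer)
    return layer.quant_method
```

Calling `build(False)` leaves the layer on the quantized path, and `build(True)` leaves it on the BF16 path, which matches the behaviour the document reports for the patched `utils.py`.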
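Review note on the dual global scale challenge for `fused_wqa_wkv`: one way to give two fused projections a single global scale is to pick the larger one and fold the ratio into each projection's block scales, which is the approach the test in this commit takes for gate and up. The sketch below uses plain Python lists and a hypothetical helper name, `fold_global_scales`. It ignores the rounding that happens when the folded scales are cast back to FP8 E4M3, so it shows the algebra only.

```python
def fold_global_scales(block_scales, global_scales):
    """Fold per-projection global scales into block scales.

    block_scales: one list of block scales per projection.
    global_scales: one global scale per projection.
    Returns (folded_block_scales, shared_global_scale).
    """
    shared = max(global_scales)
    folded = [
        [s * (g / shared) for s in scales]
        for scales, g in zip(block_scales, global_scales)
    ]
    return folded, shared


# Two projections with different global scales (illustrative values).
folded, shared = fold_global_scales([[2.0, 4.0], [1.0]], [0.5, 0.25])

# The effective scale (block scale times global scale) is unchanged:
# 1.0 * 0.25 == 0.5 * 0.5 for the second projection.
effective = [[s * shared for s in scales] for scales in folded]
```

Choosing the maximum as the shared scale means every ratio is at most 1, so folded block scales never grow beyond their original values.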
149
cutedsl/shared_expert_pipeline.py
Normal file
@@ -0,0 +1,149 @@
+"""CuTeDSL Shared Expert Pipeline
+
+NVFP4 inference for DeepSeek V4 shared experts.
+Uses ScaledGroupedGemmKernel with num_groups=1.
+
+Pipeline:
+1. Quantize activation: BF16 → NVFP4 (using warmup gs)
+2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16
+3. SiLU(gate) * up → BF16
+4. Re-quantize: BF16 → NVFP4 (using warmup gs)
+5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16
+"""
+
+import torch
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
+
+from cutedsl.bridge import (
+    quantize_activation_nvfp4,
+    assemble_scales_3d_side,
+    make_b_k_major,
+    run_nvfp4_grouped_gemm,
+)
+
+
+class CuTeDSLSharedExpertRunner:
+    """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1)."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        max_num_tokens: int,
+        device: str = "cuda",
+        swiglu_limit: float = 10.0,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.max_num_tokens = max_num_tokens
+        self.device = device
+        self.swiglu_limit = swiglu_limit
+
+        # Weights (set after construction, then call finalize_weights)
+        self.l1_fp4 = None
+        self.l1_sf = None
+        self.l1_gs = None
+        self.l2_fp4 = None
+        self.l2_sf = None
+        self.l2_gs = None
+
+        # Processed weights (set by finalize_weights)
+        self._l1_mat_b = None
+        self._l2_mat_b = None
+        self._l1_scale_b = None
+        self._l2_scale_b = None
+        self._l1_gsb = None
+        self._l2_gsb = None
+
+        # Activation global scales (set by compute_activation_global_scales)
+        self._l1_activation_global_scale = 1.0 / 2688
+        self._l2_activation_global_scale = 1.0 / 2688
+
+        print(f"[CLAWMINE] SharedExpert init: hidden={hidden_size} intermediate={intermediate_size} "
+              f"max_tokens={max_num_tokens} pid={os.getpid()}")
+
+    def set_swiglu_limit(self, limit: float):
+        self.swiglu_limit = limit
+
+    def finalize_weights(self):
+        """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights."""
+        # Stack weights and convert to K-major
+        self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
+        self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
+        self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
+        self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
+        self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
+        self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
+
+        # Free raw weights
+        self.l1_fp4 = None
+        self.l1_sf = None
+        self.l1_gs = None
+        self.l2_fp4 = None
+        self.l2_sf = None
+        self.l2_gs = None
+
+    def compute_activation_global_scales(self, hidden_states):
+        """Compute activation global scales from a warmup forward."""
+        with torch.no_grad():
+            act_amax = hidden_states.amax().item()
+            self._l1_activation_global_scale = 1.0 / (act_amax * 1.5) if act_amax > 0 else 1.0 / 2688
+
+            # Run L1 to get intermediate
+            l1_out = self._run_l1(hidden_states)
+            if l1_out is not None and not torch.isnan(l1_out).any():
+                gate = l1_out[:, :self.intermediate_size]
+                up = l1_out[:, self.intermediate_size:]
+                gate_silu = torch.nn.functional.silu(gate).clamp(max=self.swiglu_limit)
+                up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+                intermediate = gate_silu * up
+                int_amax = intermediate.amax().item()
+                self._l2_activation_global_scale = 1.0 / (int_amax * 1.5) if int_amax > 0 else 1.0 / 2688
+
+    def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """L1 GEMM: activation × gate_up_weight → BF16."""
+        gs = self._l1_activation_global_scale
+
+        # Quantize activation
+        x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gs)
+
+        # A-side scales: combine activation block scales with weight block scales
+        # For num_groups=1, we need to match dimensions
+        # x_sf shape: [num_tokens, hidden_size // 16]
+        # l1_scale_b shape: [1, num_scale_rows, hidden_size]
+        # Need to pad x_sf to match l1_scale_b dimensions
+
+        # Run GEMM
+        out = run_nvfp4_grouped_gemm(
+            x_fp4, self._l1_mat_b[0], x_sf, self._l1_scale_b[0],
+            num_groups=1,
+        )
+        return out
+
+    def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
+        """L2 GEMM: intermediate × down_weight → BF16."""
+        gs = self._l2_activation_global_scale
+
+        x_fp4, x_sf = quantize_activation_nvfp4(intermediate, gs)
+
+        out = run_nvfp4_grouped_gemm(
+            x_fp4, self._l2_mat_b[0], x_sf, self._l2_scale_b[0],
+            num_groups=1,
+        )
+        return out
+
+    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Full shared expert forward: L1 → SiLU → L2 → output."""
+        l1_out = self._run_l1(hidden_states)
+
+        gate = l1_out[:, :self.intermediate_size]
+        up = l1_out[:, self.intermediate_size:]
+        gate_silu = torch.nn.functional.silu(gate).clamp(max=self.swiglu_limit)
+        up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
+        intermediate = gate_silu * up
+
+        output = self._run_l2(intermediate)
+        return output
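Review note on the activation step between the two GEMMs: the clamped SwiGLU used in `run()` can be checked on plain floats without a GPU. This is a minimal sketch, assuming the same rule as the runner (SiLU of the gate clamped from above, the up branch clamped on both sides). The helper names `silu` and `clamped_swiglu` are illustrative and not part of the commit.

```python
import math


def silu(x):
    """Numerically stable SiLU: x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def clamped_swiglu(gate, up, limit=10.0):
    """Element-wise min(silu(gate), limit) * clamp(up, -limit, limit)."""
    out = []
    for g, u in zip(gate, up):
        g_act = min(silu(g), limit)            # clamp(max=limit)
        u_clamped = max(-limit, min(u, limit))  # clamp(min=-limit, max=limit)
        out.append(g_act * u_clamped)
    return out


# With limit=10, the product is bounded by 10 * 10 in magnitude.
bounded = clamped_swiglu([100.0], [50.0], limit=10.0)
```

The clamp bounds the intermediate tensor, which in turn bounds the range the L2 re-quantization step has to cover.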
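Review note on the warmup scale: `compute_activation_global_scales` derives each activation global scale as `1 / (amax * 1.5)` and falls back to `1 / 2688` when the maximum is not positive. The sketch below restates that rule on a plain list. The constant 2688 equals 6 × 448, the largest E2M1 value times the largest E4M3 value, which is presumably why it was chosen. One deliberate difference, flagged here as an assumption about intent: the sketch takes the maximum of absolute values, whereas the runner calls `amax()` on the signed tensor, and the two agree only when the largest-magnitude entry is positive.

```python
FALLBACK_SCALE = 1.0 / 2688  # 2688 = 6 (E2M1 max) * 448 (E4M3 max)


def warmup_global_scale(values, headroom=1.5):
    """Return 1 / (amax * headroom), or the fallback if amax is not positive.

    Assumption: amax is taken over absolute values (see the note above).
    """
    amax = max((abs(v) for v in values), default=0.0)
    if amax > 0:
        return 1.0 / (amax * headroom)
    return FALLBACK_SCALE


# Largest magnitude is 2.0, so the scale is 1 / (2.0 * 1.5).
scale = warmup_global_scale([0.5, -2.0])
```

The headroom factor leaves room for activations somewhat larger than those seen during warmup before the quantized range saturates.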
158
tests/test_shared_expert.py
Normal file
@@ -0,0 +1,158 @@
+"""Standalone test: Shared expert using CuTeDSL MoE runner with 1 expert.
+
+The shared expert is just "1 expert, no routing, top_k=1".
+We reuse the existing CuTeDSLMoERunner with num_experts=1.
+
+Usage: python3 test_shared_expert.py
+"""
+
+import torch
+import torch.nn.functional as F
+import sys, os, json
+from safetensors import safe_open
+
+MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
+DEVICE = "cuda:0"
+LAYER_IDX = 0
+HIDDEN_SIZE = 7168
+HC_MULT = 4
+HC_DIM = HC_MULT * HIDDEN_SIZE
+INTERMEDIATE_SIZE = 3072
+SWIGLU_LIMIT = 10.0
+NUM_TOKENS = 4
+
+E2M1_LUT = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.],
+                        dtype=torch.float32)
+
+_cache = {}
+
+def load_tensor(key, wm, model_dir):
+    if key in _cache:
+        return _cache[key]
+    shard_path = os.path.join(model_dir, wm[key])
+    with safe_open(shard_path, framework="pt") as f:
+        t = f.get_tensor(key)
+    _cache[key] = t
+    return t
+
+
+def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
+    device = packed_uint8.device
+    lut = E2M1_LUT.to(device)
+    lower = lut[(packed_uint8 & 0x0F).long()]
+    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
+    out_features = packed_uint8.shape[0]
+    in_features = packed_uint8.shape[1] * 2
+    unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device)
+    unpacked[:, 0::2] = lower
+    unpacked[:, 1::2] = upper
+    block_scale = scale_e4m3.float()
+    block_expanded = block_scale.repeat_interleave(16, dim=1)[:out_features, :in_features]
+    return (unpacked * block_expanded * global_scale).to(torch.bfloat16)
+
+
+def main():
+    torch.cuda.set_device(0)
+    torch.manual_seed(42)
+
+    sys.path.insert(0, "/root/nvfp4-megamoe-kernel")
+    from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
+
+    with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f:
+        wm = json.load(f)["weight_map"]
+    P = lambda key: load_tensor(key, wm, MODEL_PATH).to(DEVICE)
+
+    print("=== Shared Expert Test (CuTeDSL MoE runner, 1 expert) ===\n")
+
+    # Load shared expert weights
+    prefix = f"model.layers.{LAYER_IDX}.mlp.shared_experts"
+
+    gate_w = P(f"{prefix}.gate_proj.weight")
+    gate_sf = P(f"{prefix}.gate_proj.weight_scale")
+    gate_gs = P(f"{prefix}.gate_proj.weight_scale_2").item()
+    up_w = P(f"{prefix}.up_proj.weight")
+    up_sf = P(f"{prefix}.up_proj.weight_scale")
+    up_gs = P(f"{prefix}.up_proj.weight_scale_2").item()
+    down_w = P(f"{prefix}.down_proj.weight")
+    down_sf = P(f"{prefix}.down_proj.weight_scale")
+    down_gs = P(f"{prefix}.down_proj.weight_scale_2").item()
+
+    print(f"gate_proj: shape={gate_w.shape} gs={gate_gs:.8f}")
+    print(f"up_proj: shape={up_w.shape} gs={up_gs:.8f}")
+    print(f"down_proj: shape={down_w.shape} gs={down_gs:.8f}")
+
+    # Stack gate + up into gate_up_proj (same format as MoE L1)
+    gate_up_w = torch.cat([gate_w, up_w], dim=0)
+    gate_up_sf = torch.cat([gate_sf, up_sf], dim=0)
+    mgs = max(gate_gs, up_gs)
+    if gate_gs != up_gs:
+        sf32 = gate_up_sf.float()
+        # gate and up were concatenated along dim 0, so rescale by row range
+        sf32[:INTERMEDIATE_SIZE] *= (gate_gs / mgs)
+        sf32[INTERMEDIATE_SIZE:] *= (up_gs / mgs)
+        gate_up_sf = sf32.to(torch.float8_e4m3fn)
+
+    # Convert to CuTeDSL format
+    l1_fp4 = gate_up_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
+    l1_sf = gate_up_sf.permute(1, 0).contiguous()
+    l2_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
+    l2_sf = down_sf.permute(1, 0).contiguous()
+
+    # Create MoE runner with 1 expert
+    runner = CuTeDSLMoERunner(
+        num_experts=1, hidden_size=HC_DIM,
+        intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=8192,
+        top_k=1, device=DEVICE,
+    )
+    runner.l1_fp4 = [l1_fp4]
+    runner.l1_sf = [l1_sf]
+    runner.l1_gs = [mgs]
+    runner.l2_fp4 = [l2_fp4]
+    runner.l2_sf = [l2_sf]
+    runner.l2_gs = [down_gs]
+    runner.set_swiglu_limit(SWIGLU_LIMIT)
+
+    # Warmup
+    dummy = torch.randn(NUM_TOKENS, HC_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
+    dummy_topk_ids = torch.zeros(NUM_TOKENS, 1, dtype=torch.int64, device=DEVICE)
+    dummy_topk_weights = torch.ones(NUM_TOKENS, 1, dtype=torch.float32, device=DEVICE)
+    runner.compute_activation_global_scales(dummy, dummy_topk_weights, dummy_topk_ids)
+    print(f"Warmup gs: L1={runner._l1_activation_global_scale:.6f} L2={runner._l2_activation_global_scale:.6f}")
+
+    # Run CuTeDSL
+    print("\n--- CuTeDSL Forward ---")
+    hidden = torch.randn(NUM_TOKENS, HC_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
+    topk_ids = torch.zeros(NUM_TOKENS, 1, dtype=torch.int64, device=DEVICE)
+    topk_weights = torch.ones(NUM_TOKENS, 1, dtype=torch.float32, device=DEVICE)
+
+    with torch.no_grad():
+        output = runner.run(hidden, topk_weights, topk_ids)
+    print(f"CuTeDSL output: amax={output.amax():.4f} NaN={torch.isnan(output).any()}")
+
+    # BF16 reference
+    print("\n--- BF16 Reference ---")
+    gate_bf16 = dequant_nvfp4(gate_w, gate_sf, gate_gs)
+    up_bf16 = dequant_nvfp4(up_w, up_sf, up_gs)
+    down_bf16 = dequant_nvfp4(down_w, down_sf, down_gs)
+
+    with torch.no_grad():
+        gate = hidden @ gate_bf16.T
+        up = hidden @ up_bf16.T
+        gate_silu = F.silu(gate).clamp(max=SWIGLU_LIMIT)
+        up = up.clamp(min=-SWIGLU_LIMIT, max=SWIGLU_LIMIT)
+        intermediate = gate_silu * up
+        ref_output = intermediate @ down_bf16.T
+
+    print(f"BF16 ref: amax={ref_output.amax():.4f}")
+
+    # Compare
+    cos = F.cosine_similarity(ref_output.flatten().unsqueeze(0), output.flatten().unsqueeze(0)).item()
+    mse = (ref_output - output).pow(2).mean().item()
+    print(f"\n=== RESULT: cosine={cos:.6f} MSE={mse:.6e} ===")
+    if cos >= 0.98:
+        print("✅ PASS")
+    else:
+        print("❌ FAIL")
+
+
+if __name__ == "__main__":
+    main()
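Review note on `dequant_nvfp4`: the nibble order and scaling can be verified without torch. The sketch below mirrors the reference dequantization on plain Python lists, using the same lookup table and the same convention as the test (low nibble first, one block scale per 16 values). The function names `unpack_fp4_row` and `dequant_row` are illustrative and not part of the commit.

```python
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def unpack_fp4_row(packed_bytes):
    """Expand each byte into two E2M1 values: low nibble, then high nibble."""
    out = []
    for b in packed_bytes:
        out.append(E2M1_LUT[b & 0x0F])
        out.append(E2M1_LUT[(b >> 4) & 0x0F])
    return out


def dequant_row(packed_bytes, block_scales, global_scale, block=16):
    """value * block_scale * global_scale, one block scale per `block` values."""
    values = unpack_fp4_row(packed_bytes)
    return [v * block_scales[i // block] * global_scale
            for i, v in enumerate(values)]


# 0x21 -> low nibble 1 (0.5), high nibble 2 (1.0)
# 0xF7 -> low nibble 7 (6.0), high nibble 15 (-6.0)
decoded = unpack_fp4_row([0x21, 0xF7])
```

Checking a handful of bytes by hand this way is a cheap guard against swapping the nibble order, which would still produce finite outputs but a poor cosine score.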
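Review note on the pass criterion: the test compares the kernel output with the BF16 reference using cosine similarity (threshold 0.98) and mean squared error. A minimal sketch of both metrics on flat lists is shown below. It is not the torch implementation: `torch.nn.functional.cosine_similarity` guards small norms with an epsilon, while this version simply returns 0.0 for a zero vector, which is an assumption made for the example.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # assumption: treat a zero vector as "no similarity"
    return dot / (norm_a * norm_b)


def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def passes(reference, output, threshold=0.98):
    """Same pass rule as the test: cosine at or above the threshold."""
    return cosine_similarity(reference, output) >= threshold
```

Cosine similarity is insensitive to a uniform scale error, so a wrong global scale can still pass this check; reporting MSE alongside it, as the test does, helps catch that case.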