diff --git a/fp8_to_nvfp4_streaming.py b/fp8_to_nvfp4_streaming.py index bdbcf08..6a668f4 100644 --- a/fp8_to_nvfp4_streaming.py +++ b/fp8_to_nvfp4_streaming.py @@ -54,22 +54,30 @@ from tqdm import tqdm # Classification: which tensors do we quantize, which do we preserve? # --------------------------------------------------------------------------- +# NVFP4-everything: only preserve 1D/non-weight tensors that can't be NVFP4 PRESERVE_REGEXES = [ - r".*lm_head.*", - r".*embed_tokens.*", - r".*\.(mlp|ffn)\.gate(\.weight)?$", # MoE router (NOT gate_proj) - r".*norm.*", - r".*indexer.*", # V3.2 DSA / V4 CSA indexer - r".*hyper_conn.*", # V4 mHC - r".*\.mhc.*", - r".*hc_attn.*", # V4 hyper-connection attn - r".*hc_ffn.*", # V4 hyper-connection ffn - r".*hc_head.*", # V4 hyper-connection head - r".*scoring.*", - r".*attn_sink.*", # V4 attention sink - r".*compressor\.ape.*", # V4 compressor absolute pos encoding - r".*tid2eid.*", # V4 MoE token-to-expert mapping + r".*embed_tokens.*", # embeddings (kept in original precision) + r".*\.(mlp|ffn)\.gate(\.weight)?$", # MoE router (1D or small gate, not a GEMM weight) + r".*norm.*", # all norms (1D) + r".*indexer.*", # V4 CSA indexer (non-GEMM) + r".*scoring.*", # V4 scoring tensors + r".*attn_sink.*", # V4 attention sink (scalar/1D) + r".*compressor\.ape.*", # V4 compressor APE (1D) + r".*tid2eid.*", # V4 MoE token-to-expert mapping (1D) r".*\.bias$", # any biases + r".*hc_attn_base.*", # V4 hyper-connection scalars + r".*hc_attn_fn.*", + r".*hc_ffn_base.*", + r".*hc_ffn_fn.*", + r".*hc_head_scale.*", + r".*compressor\.wgate\.weight$", # V4 compressor gate (small, preserve) + r".*compressor\.wkv\.weight$", # V4 compressor KV proj (small, preserve) + r".*indexer\.wq_b\.weight$", # V4 indexer projections (small, preserve) + r".*indexer\.wkv\.weight$", + r".*indexer\.compressor\.wkv\.weight$", + r".*indexer\.gate_proj\.weight$", + r".*indexer\.compressor\.wgate\.weight$", + r".*indexer\.q_b_proj\.weight$", ] PRESERVE_RE = re.compile("|".join(f"(?:{p})" for p in PRESERVE_REGEXES))