1 Commits

View File

@@ -54,22 +54,30 @@ from tqdm import tqdm
# Classification: which tensors do we quantize, which do we preserve? # Classification: which tensors do we quantize, which do we preserve?
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# NVFP4-everything: only preserve 1D/non-weight tensors that can't be NVFP4
PRESERVE_REGEXES = [ PRESERVE_REGEXES = [
r".*lm_head.*", r".*embed_tokens.*", # embeddings (kept in original precision)
r".*embed_tokens.*", r".*\.(mlp|ffn)\.gate(\.weight)?$", # MoE router (1D or small gate, not a GEMM weight)
r".*\.(mlp|ffn)\.gate(\.weight)?$", # MoE router (NOT gate_proj) r".*norm.*", # all norms (1D)
r".*norm.*", r".*indexer.*", # V4 CSA indexer (non-GEMM)
r".*indexer.*", # V3.2 DSA / V4 CSA indexer r".*scoring.*", # V4 scoring tensors
r".*hyper_conn.*", # V4 mHC r".*attn_sink.*", # V4 attention sink (scalar/1D)
r".*\.mhc.*", r".*compressor\.ape.*", # V4 compressor APE (1D)
r".*hc_attn.*", # V4 hyper-connection attn r".*tid2eid.*", # V4 MoE token-to-expert mapping (1D)
r".*hc_ffn.*", # V4 hyper-connection ffn
r".*hc_head.*", # V4 hyper-connection head
r".*scoring.*",
r".*attn_sink.*", # V4 attention sink
r".*compressor\.ape.*", # V4 compressor absolute pos encoding
r".*tid2eid.*", # V4 MoE token-to-expert mapping
r".*\.bias$", # any biases r".*\.bias$", # any biases
r".*hc_attn_base.*", # V4 hyper-connection scalars
r".*hc_attn_fn.*",
r".*hc_ffn_base.*",
r".*hc_ffn_fn.*",
r".*hc_head_scale.*",
r".*compressor\.wgate\.weight$", # V4 compressor gate (small, preserve)
r".*compressor\.wkv\.weight$", # V4 compressor KV proj (small, preserve)
r".*indexer\.wq_b\.weight$", # V4 indexer projections (small, preserve)
r".*indexer\.wkv\.weight$",
r".*indexer\.compressor\.wkv\.weight$",
r".*indexer\.gate_proj\.weight$",
r".*indexer\.compressor\.wgate\.weight$",
r".*indexer\.q_b_proj\.weight$",
] ]
PRESERVE_RE = re.compile("|".join(f"(?:{p})" for p in PRESERVE_REGEXES)) PRESERVE_RE = re.compile("|".join(f"(?:{p})" for p in PRESERVE_REGEXES))