- CuTeDSLNvfp4Linear: generic single-GEMM runner for any NVFP4 projection - test_attention.py: tests q_a_proj, q_b_proj, kv_proj, o_b_proj vs BF16 - Same pad+swizzle pattern as shared expert, but no SiLU/fusion
174 lines
6.2 KiB
Python
174 lines
6.2 KiB
Python
"""Standalone test: Attention projections using CuTeDSL NVFP4 linear runner.

Tests q_a_proj, q_b_proj, kv_proj and o_b_proj against a BF16 reference.
o_a_proj is BF16 (not NVFP4), so it is not tested here.

Usage: python3 test_attention.py
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import sys, os, json
|
|
from safetensors import safe_open
|
|
|
|
MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"  # checkpoint directory (holds the safetensors index)
DEVICE = "cuda:0"
LAYER_IDX = 0  # decoder layer whose self_attn projections are tested
HIDDEN_SIZE = 7168  # width of the hidden_states fed to q_a_proj and kv_proj
NUM_TOKENS = 4  # rows of random activations per projection

# E2M1 (FP4) code -> float value, indexed by a 4-bit nibble. Codes 0-7 are the
# non-negative magnitudes; codes 8-15 are the same magnitudes with the sign set
# (so code 8 is negative zero).
_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_LUT = torch.tensor(
    [*_E2M1_MAGNITUDES, *(-m for m in _E2M1_MAGNITUDES)],
    dtype=torch.float32,
)
|
|
|
|
# Tensors already read from disk, keyed by (model_dir, tensor name) so the same
# tensor name under two checkpoint directories cannot collide.
_cache = {}


def load_tensor(key, wm, model_dir):
    """Return tensor ``key`` from the safetensors shards under ``model_dir``.

    Args:
        key: Tensor name as it appears in the checkpoint.
        wm: ``weight_map`` dict from ``model.safetensors.index.json``, mapping
            tensor names to shard file names (relative to ``model_dir``).
        model_dir: Directory containing the shard files.

    Returns:
        The tensor read with ``safe_open(..., framework="pt")``. Results are
        memoised in ``_cache``, so repeated calls return the same object without
        reopening the shard.

    Raises:
        KeyError: if ``key`` is not cached and is missing from ``wm``.
    """
    cache_key = (model_dir, key)
    if cache_key in _cache:
        return _cache[cache_key]
    shard_path = os.path.join(model_dir, wm[key])
    with safe_open(shard_path, framework="pt") as f:
        tensor = f.get_tensor(key)
    _cache[cache_key] = tensor
    return tensor
|
|
|
|
|
|
def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
    """Dequantize a packed NVFP4 weight to BF16 for use as a reference.

    Each byte of ``packed_uint8`` holds two E2M1 codes: the low nibble is the
    even-indexed element and the high nibble the following odd-indexed one.
    Every run of 16 consecutive elements along a row shares one entry of
    ``scale_e4m3``, and every element is finally multiplied by ``global_scale``.

    Args:
        packed_uint8: ``(out_features, in_features // 2)`` tensor of packed codes.
        scale_e4m3: Block scales with at least ``out_features`` rows and
            ``ceil(in_features / 16)`` columns; extra (padding) rows/columns
            are ignored.
        global_scale: Per-tensor scale multiplied into every element.

    Returns:
        ``(out_features, in_features)`` BF16 tensor on ``packed_uint8``'s device.
    """
    block = 16  # elements sharing one block scale
    device = packed_uint8.device
    out_features = packed_uint8.shape[0]
    in_features = packed_uint8.shape[1] * 2

    lut = E2M1_LUT.to(device)
    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]
    # Interleave each (lower, upper) pair along the row: (out, packed, 2) -> (out, in).
    values = torch.stack((lower, upper), dim=-1).reshape(out_features, in_features)

    n_blocks = -(-in_features // block)  # ceil, so a trailing partial block keeps its scale
    scale = scale_e4m3.float()[:out_features, :n_blocks]
    if in_features % block == 0:
        # Broadcast one scale per block instead of materialising an
        # (out_features, in_features) float32 copy via repeat_interleave.
        values = (values.view(out_features, n_blocks, block) * scale.unsqueeze(-1)).view(
            out_features, in_features
        )
    else:
        values = values * scale.repeat_interleave(block, dim=1)[:, :in_features]
    return (values * global_scale).to(torch.bfloat16)
|
|
|
|
|
|
def test_projection(name, weight, weight_sf, weight_gs, hidden_states, in_features, out_features,
                    min_cosine=0.98):
    """Run one NVFP4 projection through CuTeDSL and compare it with a BF16 reference.

    The reference is ``hidden_states @ W.T``, where ``W`` is ``weight``
    dequantized by ``dequant_nvfp4``. A one-line report (cosine similarity,
    MSE, max of each output, pass/fail mark) is printed.

    Args:
        name: Label used in the printed report.
        weight: Packed NVFP4 weight as stored in the checkpoint,
            ``(out_features, in_features // 2)`` bytes.
        weight_sf: Block scales for ``weight``.
        weight_gs: Per-tensor global scale for ``weight``.
        hidden_states: Activations of shape ``(num_tokens, in_features)``.
        in_features: Unpacked input width passed to the runner.
        out_features: Output width passed to the runner.
        min_cosine: Cosine similarity at or above which the line is marked as passing.

    Returns:
        The cosine similarity (a Python float) between the CuTeDSL output and the reference.
    """
    kernel_root = "/root/nvfp4-megamoe-kernel"
    # This function runs once per projection; prepend the kernel checkout only once.
    if kernel_root not in sys.path:
        sys.path.insert(0, kernel_root)
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    # Convert weight to CuTeDSL format: (out, in_packed) uint8 → (in_packed, out) float4
    fp4 = [weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    sf = [weight_sf.permute(1, 0).contiguous()]
    gs = [weight_gs]

    runner = CuTeDSLNvfp4Linear(
        in_features=in_features,
        out_features=out_features,
        max_num_tokens=8192,
        device=DEVICE,
    )
    runner.fp4 = fp4
    runner.sf = sf
    runner.gs = gs
    runner.finalize_weights()

    # Warmup: one-time runner setup plus activation global-scale computation on
    # this input (exact semantics live in the CuTeDSL runner — TODO confirm).
    runner._ensure_initialized()
    runner.compute_activation_global_scale(hidden_states)

    # Run CuTeDSL
    with torch.no_grad():
        output = runner.run(hidden_states)

    # BF16 reference
    bf16_w = dequant_nvfp4(weight, weight_sf, weight_gs)
    with torch.no_grad():
        ref = hidden_states @ bf16_w.T

    # Compare in float32: bf16 reductions over the whole output lose precision,
    # and casting both sides also copes with a runner output dtype that differs
    # from the reference's.
    ref_f = ref.float().flatten()
    out_f = output.float().flatten()
    cos = F.cosine_similarity(ref_f.unsqueeze(0), out_f.unsqueeze(0)).item()
    mse = (ref_f - out_f).pow(2).mean().item()
    status = "✅" if cos >= min_cosine else "❌"
    print(f" {name}: cosine={cos:.6f} MSE={mse:.6e} amax_ref={ref.amax():.4f} amax_out={output.amax():.4f} {status}")
    return cos
|
|
|
|
|
|
def main():
    """Load the NVFP4 self-attention projections of layer LAYER_IDX and test each one.

    Each projection's dimensions come from its weight shape. q_a_proj and
    kv_proj share one random ``hidden_states`` tensor; q_b_proj and o_b_proj
    each get a fresh random stand-in for their upstream activations, sized to
    their own ``in_features``.

    Returns:
        0 if every projection reaches the cosine threshold, otherwise 1, so the
        result can be used as a process exit status.

    Raises:
        ValueError: if a projection fed ``hidden_states`` does not have
            ``in_features == HIDDEN_SIZE``.
    """
    min_cosine = 0.98  # pass threshold; same default test_projection uses for its per-line mark

    torch.cuda.set_device(0)
    torch.manual_seed(42)

    with open(os.path.join(MODEL_PATH, "model.safetensors.index.json"), encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    def fetch(key):
        # Checkpoint tensor (memoised by load_tensor), moved to the test device.
        return load_tensor(key, wm, MODEL_PATH).to(DEVICE)

    prefix = f"model.layers.{LAYER_IDX}.self_attn"

    print("=== Attention Projection Tests (CuTeDSL NVFP4 Linear) ===\n")

    # Projections are tested in this order. "reads_hidden" marks the ones whose
    # input is hidden_states itself; the others consume an upstream output.
    projs = {
        "q_a_proj": {"key": f"{prefix}.q_a_proj", "reads_hidden": True},
        "q_b_proj": {"key": f"{prefix}.q_b_proj", "reads_hidden": False},  # input: q_a_proj output
        "kv_proj": {"key": f"{prefix}.kv_proj", "reads_hidden": True},
        "o_b_proj": {"key": f"{prefix}.o_b_proj", "reads_hidden": False},  # input: o_a_proj output
    }

    # Load weights and determine dimensions from shapes
    for name, info in projs.items():
        key = info["key"]
        w = fetch(f"{key}.weight")
        sf = fetch(f"{key}.weight_scale")
        gs = fetch(f"{key}.weight_scale_2").item()
        out_features = w.shape[0]
        in_features = w.shape[1] * 2  # two FP4 codes per packed byte
        info.update(weight=w, sf=sf, gs=gs, in_features=in_features, out_features=out_features)
        print(f" {name}: weight={w.shape} → in={in_features} out={out_features} gs={gs:.8f}")

    print()

    # hidden is drawn lazily at the first projection that reads it, so the RNG
    # draw order is: hidden_states, q_b_proj input, o_b_proj input.
    hidden = None
    results = {}
    for name, info in projs.items():
        in_features = info["in_features"]
        if info["reads_hidden"]:
            if in_features != HIDDEN_SIZE:
                raise ValueError(f"{name}: expected in_features={HIDDEN_SIZE}, got {in_features}")
            if hidden is None:
                hidden = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0
            x = hidden
        else:
            x = torch.randn(NUM_TOKENS, in_features, dtype=torch.bfloat16, device=DEVICE) * 2.0
        results[name] = test_projection(name, info["weight"], info["sf"], info["gs"],
                                        x, in_features, info["out_features"])

    print("\n=== SUMMARY ===")
    all_pass = True
    for name, cos in results.items():
        # Decide pass/fail with `cos >= min_cosine` only: `cos < min_cosine` is
        # False for NaN, which would let a NaN result count towards ALL PASS.
        passed = cos >= min_cosine
        all_pass = all_pass and passed
        status = "✅" if passed else "❌"
        print(f" {name}: cosine={cos:.6f} {status}")

    if all_pass:
        print("\n✅ ALL PASS")
        return 0
    print("\n❌ SOME FAILED")
    return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so a failing run is
    # visible to shells/CI; a None return exits with status 0.
    sys.exit(main())
|