400 lines
18 KiB
Python
400 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CSA Sparse Attention Test
|
|
|
|
Tests the csa_sparse_attention_batched function with simulated compressor output:
|
|
1. Create compressed KV cache (simulating compressor output)
|
|
2. Create topk_indices (simulating indexer output)
|
|
3. Do sparse attention on compressed KV at topk positions
|
|
4. Do SWA attention on the window
|
|
5. Merge with sink weights
|
|
6. Compare against full attention reference
|
|
|
|
Usage (on B200):
|
|
cd /root/nvfp4-megamoe-kernel
|
|
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_csa_sparse_attn_b200.py
|
|
"""
|
|
|
|
import sys, os, json, torch, torch.nn.functional as F, time
|
|
from safetensors import safe_open
|
|
|
|
REPO = "/root/nvfp4-megamoe-kernel"
|
|
sys.path.insert(0, REPO)
|
|
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
|
DEV = "cuda:0"
|
|
|
|
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
|
|
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
|
|
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
|
|
|
|
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,6.], dtype=torch.float32)
|
|
|
|
_cache = {}


def P(k, wm, md):
    """Return checkpoint tensor ``k``, loading it at most once.

    Args:
        k: tensor name as it appears in the safetensors index.
        wm: the index's ``weight_map`` (tensor name -> shard file name).
        md: model directory that contains the shard files.

    Returns:
        The tensor as loaded by ``safe_open(..., framework="pt")``. It is
        memoised in the module-level ``_cache`` until that dict is cleared.

    Raises:
        KeyError: if ``k`` is not cached and not listed in ``wm``.
    """
    try:
        return _cache[k]
    except KeyError:
        pass
    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor


def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and scale by ``w``.

    The mean square is computed in float32 and the result is cast back to
    ``x``'s dtype.
    """
    mean_square = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalised = (x * inv_rms).float()
    scaled = w.float() * normalised
    return scaled.to(x.dtype)


def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build and initialise a CuTeDSL NVFP4 linear runner from checkpoint tensors.

    Args:
        w: packed FP4 weight; viewed as ``float4_e2m1fn_x2`` and transposed
            before being handed to the runner (assumed (outf, inf // 2) —
            TODO confirm against the checkpoint).
        sf: per-block scale factors (``weight_scale``); cast to
            ``float8_e4m3fn`` if stored otherwise, then transposed like ``w``
            (assumed (outf, inf // 16) — TODO confirm).
        gs_t: global scale(s) (``weight_scale_2``): one value, or two for a
            fused projection.
        inf: input features.
        outf: output features.
        fused: when True and ``gs_t`` holds two values, fold both global
            scales into the block scales so one shared global scale is used.
        lw: optional output widths of the fused parts. Only ``lw[0]`` (the
            split point) is read; the default is an even split of ``outf``.

    Returns:
        A ``CuTeDSLNvfp4Linear`` with weights finalised and kernels initialised.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    # After this transpose the output features run along dim 1 (under the
    # layout assumption above).
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # Rescale each fused part's block scales relative to the shared
            # ``gs``. The split point is an output width, so it must be
            # applied to columns (output axis), not rows (input groups).
            s32 = s.float()
            sp = lw[0] if lw else outf // 2
            s32[:, :sp] *= g1 / gs
            s32[:, sp:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r


def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Precompute a RoPE lookup table on the CPU in float32.

    Row ``p`` is ``[cos(p * f_0), ..., cos(p * f_{h-1}), sin(p * f_0), ..., sin(p * f_{h-1})]``
    with ``h = rope_dim // 2`` and ``f_i = 1 / 10000 ** (i / h)``, so the
    table has shape (max_pos, 2 * h).
    """
    h = rope_dim // 2
    exponents = torch.arange(0, h, dtype=torch.float32) / h
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = pos.reshape(-1, 1) * inv_freq
    return torch.cat((angles.cos(), angles.sin()), dim=-1)


def apply_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
    """Apply GPT-J style (interleaved) rotary embedding to the tail of ``x``.

    Channels ``x[..., :nope_dim]`` are copied unchanged. Within
    ``x[..., nope_dim:]`` every (even, odd) channel pair ``(e, o)`` becomes
    ``(e*cos - o*sin, e*sin + o*cos)`` for its token's position.

    Args:
        x: 2-D or 3-D tensor whose first axis lines up with ``positions``;
            for 3-D input the angles broadcast over axis 1.
        positions: row indices into ``cos_sin``.
        cos_sin: table whose rows are a cos half followed by a sin half.
        nope_dim: number of leading channels left untouched.
        rope_dim: rotary width; 0 disables rotation.

    Returns:
        A new tensor, or ``x`` itself when ``rope_dim == 0`` or ``x`` is empty.
    """
    if x.numel() == 0 or rope_dim == 0:
        return x
    h = rope_dim // 2
    table = cos_sin[positions]
    cos = table[..., :h].to(x.dtype)
    sin = table[..., h:2 * h].to(x.dtype)
    if x.dim() == 3:
        cos = cos[:, None]
        sin = sin[:, None]
    tail = x[..., nope_dim:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    out = x.clone()
    rotated = out[..., nope_dim:]
    rotated[..., 0::2] = even * cos - odd * sin
    rotated[..., 1::2] = even * sin + odd * cos
    return out


def apply_inv_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
    """Undo GPT-J style (interleaved) rotary embedding on the tail of ``x``.

    Rotates each (even, odd) channel pair of ``x[..., nope_dim:]`` by the
    negative angle: ``(e, o) -> (e*cos + o*sin, -e*sin + o*cos)``. Channels
    ``x[..., :nope_dim]`` are copied unchanged.

    Args:
        x: 2-D or 3-D tensor whose first axis lines up with ``positions``;
            for 3-D input the angles broadcast over axis 1.
        positions: row indices into ``cos_sin``.
        cos_sin: table whose rows are a cos half followed by a sin half.
        nope_dim: number of leading channels left untouched.
        rope_dim: rotary width; 0 disables rotation.

    Returns:
        A new tensor, or ``x`` itself when ``rope_dim == 0`` or ``x`` is empty.
    """
    if x.numel() == 0 or rope_dim == 0:
        return x
    h = rope_dim // 2
    table = cos_sin[positions]
    cos = table[..., :h].to(x.dtype)
    sin = table[..., h:2 * h].to(x.dtype)
    if x.dim() == 3:
        cos = cos[:, None]
        sin = sin[:, None]
    tail = x[..., nope_dim:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    out = x.clone()
    rotated = out[..., nope_dim:]
    rotated[..., 0::2] = even * cos + odd * sin
    rotated[..., 1::2] = -even * sin + odd * cos
    return out


def kv_quantize_fp8(kv_bf16):
|
|
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
|
scale = fp8_max / amax
|
|
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
|
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
|
return kv_fp8, inv_scale
|
|
|
|
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
|
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
|
|
|
def causal_prefill_attention(q, kv, scale):
|
|
T, NH, HD = q.shape
|
|
q_t = q.permute(1, 0, 2)
|
|
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1)
|
|
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
|
|
return out.permute(1, 0, 2)
|
|
|
|
|
|
def csa_sparse_gather_attention(q, compressed_kv, topk_indices, topk_lens, scale, cos_sin, nope_dim, rope_dim):
|
|
"""CSA sparse attention: gather compressed KV at topk positions, attend.
|
|
|
|
q: (T, NH, HD) with RoPE already applied
|
|
compressed_kv: (num_compressed, HD) — all compressed KV vectors
|
|
topk_indices: (T, num_topk) — which compressed positions to attend to
|
|
topk_lens: (T,) — how many of the topk_indices are valid
|
|
"""
|
|
T, NH, HD = q.shape
|
|
device = q.device
|
|
num_topk = topk_indices.shape[-1]
|
|
|
|
# Gather compressed KV at topk positions
|
|
# Clamp to valid range
|
|
safe_idx = topk_indices.clamp(min=0, max=compressed_kv.shape[0] - 1)
|
|
# (T, num_topk, HD)
|
|
k_gathered = compressed_kv[safe_idx]
|
|
|
|
# Mask invalid positions (set to 0)
|
|
valid_mask = torch.arange(num_topk, device=device).unsqueeze(0) < topk_lens.unsqueeze(1)
|
|
k_gathered = k_gathered * valid_mask.unsqueeze(-1).to(k_gathered.dtype)
|
|
|
|
# Apply RoPE to gathered K at their original (compressed) positions
|
|
if rope_dim > 0 and cos_sin is not None:
|
|
kv_positions = safe_idx # The positions in the compressed cache
|
|
# BUT: compressed position i represents the i-th group of compress_ratio tokens
|
|
# The "position" for RoPE should be the original token position, not the compressed index
|
|
# For now, use the compressed index as a proxy (this is a simplification)
|
|
# In the real pipeline, the compressor stores KV with RoPE already applied
|
|
pass # Skip RoPE for now — the compressor already applies it
|
|
|
|
# Multi-head attention: expand K for all heads
|
|
# k_gathered: (T, num_topk, HD) → (T, NH, num_topk, HD)
|
|
k_heads = k_gathered.unsqueeze(1).expand(-1, NH, -1, -1)
|
|
v_heads = k_heads.clone()
|
|
|
|
# Q: (T, NH, HD) → (T*NH, 1, HD)
|
|
q_2d = q.reshape(T * NH, 1, HD)
|
|
k_2d = k_heads.reshape(T * NH, num_topk, HD)
|
|
v_2d = v_heads.reshape(T * NH, num_topk, HD)
|
|
|
|
# Attention mask: (T, num_topk) → (T*NH, 1, num_topk)
|
|
attn_mask = valid_mask.unsqueeze(1).expand(-1, NH, -1).reshape(T * NH, 1, num_topk)
|
|
|
|
out = F.scaled_dot_product_attention(
|
|
q_2d, k_2d, v_2d,
|
|
attn_mask=attn_mask if not attn_mask.all() else None,
|
|
scale=scale,
|
|
)
|
|
|
|
return out.squeeze(1).reshape(T, NH, HD)
|
|
|
|
|
|
def swa_cache_attention(q, swa_kv_cache, inv_scale_cache, positions, block_size, scale, window_size):
    """Sliding-window attention for one decode token, reading the paged FP8 KV cache.

    Only the slots inside the window are gathered and dequantized.
    Dequantization is per row, so this gives the same result as
    dequantizing the whole history and then slicing.

    Args:
        q: (1, NH, HD) query of the single decode token.
        swa_kv_cache: paged cache indexed as ``[slot // block_size, slot % block_size]``;
            ``uint8`` storage is reinterpreted as ``float8_e4m3fn``.
        inv_scale_cache: per-slot dequantization scales indexed by absolute slot.
        positions: only ``positions[0]`` (the decode token's position) is used.
        block_size: tokens per cache block.
        scale: softmax scale.
        window_size: number of most recent tokens attended, including the
            current one.

    Returns:
        (1, NH, HD) attention output.

    Raises:
        ValueError: if ``window_size < 1`` (the window would be empty).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    pos = positions[0].item()
    start = max(0, pos - window_size + 1)
    slots = torch.arange(start, pos + 1, dtype=torch.int64, device=q.device)
    kv_cached = swa_kv_cache[slots // block_size, slots % block_size]
    if swa_kv_cache.dtype == torch.uint8:
        kv_cached = kv_cached.view(torch.float8_e4m3fn)
    kv_window = kv_dequantize_fp8(kv_cached, inv_scale_cache[slots])

    NH = q.shape[1]
    q_t = q.permute(1, 0, 2)
    kv_exp = kv_window.unsqueeze(0).expand(NH, -1, -1)
    out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
    return out.permute(1, 0, 2)


def _project_q_kv(normed, positions, r_qa, r_qb, r_kv, qn, kvn, cos_sin):
    """Run the attention input projections for the tokens in ``normed``.

    Returns ``(q_rope, kv_n, kv_rope)``:
      - ``q_rope``: queries viewed as (T, NH, HD) with RoPE applied at ``positions``;
      - ``kv_n``: the RMS-normalised shared KV before RoPE;
      - ``kv_rope``: that KV with RoPE applied.
    """
    qa = r_qa.run(normed)
    kv = r_kv.run(normed)
    qa_n = rms(qa, qn, EPS)
    kv_n = rms(kv, kvn, EPS)
    q = r_qb.run(qa_n).view(normed.shape[0], NH, HD)
    q_rope = apply_gptj_rope(q, positions, cos_sin, NOPE, ROPE)
    kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
    return q_rope, kv_n, kv_rope


def _write_swa_cache(swa_cache, swa_inv_scale, positions, kv_rope, block_size):
    """FP8-quantize ``kv_rope`` and store it, with its inverse scales, at ``positions`` in the paged SWA cache."""
    kv_fp8, inv_s = kv_quantize_fp8(kv_rope)
    swa_cache[positions // block_size, positions % block_size] = kv_fp8.view(torch.uint8)
    swa_inv_scale[positions] = inv_s


def _output_projection(o, positions, woa_3d, r_wob, cos_sin):
    """Undo RoPE on attention output ``o`` (T, NH, HD) and apply the grouped o_a and o_b projections."""
    t = o.shape[0]
    o_inv = apply_inv_gptj_rope(o, positions, cos_sin, NOPE, ROPE)
    o_grp = o_inv.reshape(t, OG, HPG * HD).permute(1, 0, 2)
    z = torch.bmm(o_grp, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(t, OG * OL)
    return r_wob.run(z)


def _cosine(a, b):
    """Cosine similarity between ``a`` and ``b`` flattened to single vectors (in float32)."""
    return F.cosine_similarity(a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()).item()


def test_csa_layer(layer_id, compress_ratio):
    """Simulate one CSA/HCA decode step for a layer and compare it to full attention.

    Steps:
      1. Prefill ``N`` tokens: project Q/KV, fill the paged FP8 SWA cache, and
         build a compressed KV cache by average-pooling every ``compress_ratio``
         normalised KV vectors (a stand-in for the learned compressor).
      2. Decode one token: sparse attention over every compressed entry plus
         the decode token's own (RoPE'd) KV, and SWA attention over the last
         ``WINDOW`` cached tokens.
      3. Blend the two branches per head with ``sigmoid(sinks)``.
      4. Compare against full causal attention over all ``N + 1`` tokens, both
         on the raw attention output and after the output projections.

    Args:
        layer_id: decoder layer whose weights are loaded.
        compress_ratio: tokens per compressed entry. It also picks the prefill
            length ``N`` (128 when ``compress_ratio >= 128``, else 16).

    Returns:
        ``(c_attn, c_full, c_swa)``: cosine of merged attention vs reference,
        of the full output projection vs reference, and of SWA-only attention
        vs reference.
    """
    torch.cuda.set_device(0)
    torch.cuda.empty_cache()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        return P(k, wm, MODEL).to(DEV)

    def runner(prefix, in_features):
        # NVFP4 linear: packed weight + block scales + global scale(s).
        w = G(f"{prefix}.weight")
        return make_runner(w, G(f"{prefix}.weight_scale"), G(f"{prefix}.weight_scale_2"), in_features, w.shape[0])

    p = f"model.layers.{layer_id}"
    a = f"{p}.self_attn"
    cr = compress_ratio

    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    woa = G(f"{a}.o_a_proj.weight")
    sinks = G(f"{a}.sinks")

    r_qa = runner(f"{a}.q_a_proj", H)
    r_qb = runner(f"{a}.q_b_proj", QL)
    r_kv = runner(f"{a}.kv_proj", H)
    r_wob = runner(f"{a}.o_b_proj", OG * OL)
    r_comp_kv = runner(f"{a}.compressor.kv_proj", H)
    r_comp_gate = runner(f"{a}.compressor.gate_proj", H)

    cos_sin = build_cos_sin(max_pos=4096).to(DEV)
    woa_3d = woa.view(OG, OL, HPG * HD)

    # Paged SWA KV cache: FP8 bytes plus one bf16 inverse scale per slot.
    block_size = 64
    max_tokens = 256
    num_blocks = (max_tokens + block_size - 1) // block_size
    swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV)
    swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV)

    N = 128 if cr >= 128 else 16  # prefill tokens (a multiple of compress_ratio)
    assert N % cr == 0, f"N={N} must be multiple of compress_ratio={cr}"
    token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)

    with torch.no_grad():
        # ── PREFILL ─────────────────────────────────────────────
        positions = torch.arange(N, dtype=torch.int64, device=DEV)
        normed = rms(emb[token_ids], anorm, EPS)
        _, kv_n, kv_rope = _project_q_kv(normed, positions, r_qa, r_qb, r_kv, qn, kvn, cos_sin)
        _write_swa_cache(swa_cache, swa_inv_scale, positions, kv_rope, block_size)

        # Compressed KV (simulated). The compressor projections are run so
        # their kernels are exercised, but their outputs are not consumed yet.
        # Compression is approximated by average-pooling every ``cr``
        # normalised KV vectors.
        num_compressed = N // cr
        r_comp_kv.run(normed)
        r_comp_gate.run(normed)
        compressed_kv = kv_n.reshape(num_compressed, cr, HD).mean(dim=1)  # (num_compressed, HD)
        # NOTE(review): RoPE is applied at the compressed index, not at an
        # original token position; this is a simplification of the real
        # compressor — confirm which position the model uses.
        compressed_kv_rope = apply_gptj_rope(
            compressed_kv.unsqueeze(1),
            torch.arange(num_compressed, dtype=torch.int64, device=DEV),
            cos_sin, NOPE, ROPE,
        ).squeeze(1)

        # ── DECODE ──────────────────────────────────────────────
        decode_id = torch.tensor([N], dtype=torch.long, device=DEV)
        pos_d = torch.tensor([N], dtype=torch.int64, device=DEV)
        normed_d = rms(emb[decode_id], anorm, EPS)
        q_rope_d, _, kv_rope_d = _project_q_kv(normed_d, pos_d, r_qa, r_qb, r_kv, qn, kvn, cos_sin)
        _write_swa_cache(swa_cache, swa_inv_scale, pos_d, kv_rope_d, block_size)
        r_comp_kv.run(normed_d)  # output not consumed (see above)

        # Append the decode token's KV as an extra sparse entry. It carries
        # RoPE at its own position, like the other entries and like the
        # reference key for this token.
        compressed_kv_all = torch.cat([compressed_kv_rope, kv_rope_d], dim=0)
        num_compressed_total = compressed_kv_all.shape[0]

        # Simulated indexer: every token selects all compressed entries.
        topk_d = torch.arange(num_compressed_total, dtype=torch.int64, device=DEV).unsqueeze(0)
        topk_lens_d = torch.tensor([num_compressed_total], dtype=torch.int64, device=DEV)
        sparse_out = csa_sparse_gather_attention(
            q_rope_d, compressed_kv_all, topk_d, topk_lens_d,
            SCALE, cos_sin, NOPE, ROPE,
        )
        swa_out = swa_cache_attention(
            q_rope_d, swa_cache, swa_inv_scale, pos_d, block_size, SCALE, WINDOW,
        )

        # NOTE(review): sigmoid(sinks) is used as a per-head blend weight
        # between the two branches. This is a heuristic; if the model treats
        # sinks as extra softmax logits, an LSE-based merge is needed instead.
        sink_w = torch.sigmoid(sinks).view(1, NH, 1)
        merged_out = sparse_out * (1 - sink_w) + swa_out * sink_w

        # ── Reference: full causal attention on all tokens ──────
        all_ids = torch.cat([token_ids, decode_id])
        all_pos = torch.arange(N + 1, dtype=torch.int64, device=DEV)
        normed_ref = rms(emb[all_ids], anorm, EPS)
        q_rope_ref, _, kv_rope_ref = _project_q_kv(normed_ref, all_pos, r_qa, r_qb, r_kv, qn, kvn, cos_sin)
        o_ref_decode = causal_prefill_attention(q_rope_ref, kv_rope_ref, SCALE)[-1:]  # decode token only

        # ── Full output pipeline ────────────────────────────────
        attn_merged = _output_projection(merged_out, pos_d, woa_3d, r_wob, cos_sin)
        attn_ref = _output_projection(o_ref_decode, pos_d, woa_3d, r_wob, cos_sin)

        # ── COMPARE ─────────────────────────────────────────────
        # Average-pooled KV will not match full attention exactly; the check
        # is that the structure is preserved and no NaN appears.
        c_attn = _cosine(merged_out, o_ref_decode)
        c_full = _cosine(attn_merged, attn_ref)
        c_swa = _cosine(swa_out, o_ref_decode)

    del r_qa, r_qb, r_kv, r_wob, r_comp_kv, r_comp_gate
    torch.cuda.empty_cache()
    _cache.clear()

    return c_attn, c_full, c_swa


def main():
    """Run the CSA sparse-attention check on one C128A and one C4A layer and print the cosines."""
    rule = "=" * 70
    print(rule)
    print(" CSA Sparse Attention Test")
    print(" Tests compressed KV gather + sparse attention + SWA merge")
    print(rule)

    # (layer id, compress ratio): layer 0 is C128A, layer 1 is C4A.
    for layer_id, ratio in ((0, 128), (1, 4)):
        c_attn, c_full, c_swa = test_csa_layer(layer_id, ratio)
        print(f" Layer {layer_id} (C{ratio}A):")
        print(f" Merged (sparse+SWA) attn cosine: {c_attn:.4f}")
        print(f" Full pipeline cosine: {c_full:.4f}")
        print(f" SWA-only cosine: {c_swa:.4f}")

    print(f"\n{rule}")
    print(" SWA-only cosine should be >0.98 (proven in decode vs prefill test)")
    print(" Merged cosine may be lower (avg-pooled KV is an approximation)")
    print(" The important thing: no NaN, reasonable values")
    print(rule)


if __name__ == "__main__":
    main()