DSV4Config: frozen dataclass with .flash() / .pro() classmethods. All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro: HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both: first 3 MoE layers = hash routing, rest = dense.
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants. No forward tests yet (depends on FMHA kernel + KV cache).
98 lines
4.2 KiB
Python
98 lines
4.2 KiB
Python
#!/usr/bin/env python3
|
|
import os
import sys

# The kernel sources loaded further down are given as paths relative to the
# kernel workspace ('dsv4/kernels/cuda/...'), so the script runs from there.
# DSV4_KERNEL_DIR overrides the location; the default is the original path.
KERNEL_DIR = os.environ.get('DSV4_KERNEL_DIR', '/root/dsv4-nvfp4-workspace/kernel')
os.chdir(KERNEL_DIR)
sys.path.insert(0, '.')

import torch
from torch.utils.cpp_extension import load
|
|
|
|
def pf(ok: object) -> str:
    """Return the status label for a check result.

    Args:
        ok: Outcome of a check; interpreted by truthiness.

    Returns:
        ``'PASS'`` when *ok* is truthy, otherwise ``'FAIL'``.
    """
    return 'PASS' if ok else 'FAIL'
|
|
|
|
# Report the CUDA environment up front. Every check in this script allocates
# tensors on device='cuda', so stop with a clear message when no CUDA device
# is usable instead of failing inside torch.cuda.get_device_name().
cuda_available = torch.cuda.is_available()
if not cuda_available:
    print(f'CUDA: {cuda_available}, Device: none', flush=True)
    sys.exit('A CUDA device is required to run the kernel checks')
print(f'CUDA: {cuda_available}, Device: {torch.cuda.get_device_name(0)}', flush=True)
# mem_get_info() returns (free, total) in bytes; only the free amount is shown.
free_bytes, _total_bytes = torch.cuda.mem_get_info()
print(f'GPU mem free: {free_bytes // 1024 // 1024} MB', flush=True)
|
|
|
|
# 1. Hash Router
# Builds the hash_router CUDA extension and checks it against a lookup-table
# oracle: each token id is expected to select row lut[token_id] as its expert
# ids, with a uniform weight of 1/k for every selected expert.
print('\n=== Hash Router ===', flush=True)
try:
    hr = load(
        name='hr2',
        sources=['dsv4/kernels/cuda/hash_router.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    vocab, k, E = 128000, 6, 256
    # Track the outcome of every case so the summary line reflects reality
    # rather than printing "ALL PASS" unconditionally.
    hash_router_ok = True
    for N in [1, 4, 64, 128, 512]:
        lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
        tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
        # Output buffers handed to the kernel; compared against the oracle below.
        ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
        oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
        hr.hash_router(tids, lut, k, ow, oi)
        torch.cuda.synchronize()
        # Index with int64: it is accepted by every PyTorch release, whereas
        # int32 index tensors are rejected by older ones.
        exp_ids = lut[tids.long()]
        exp_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')
        ids_ok = (oi == exp_ids).all().item()
        w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
        ok = ids_ok and w_ok
        hash_router_ok = hash_router_ok and ok
        print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
        # Drop references before the next (larger) case allocates.
        del lut, tids, ow, oi, exp_ids, exp_w
    print(f"Hash Router: {'ALL PASS' if hash_router_ok else 'FAILED'}", flush=True)
except Exception:
    # Best-effort script: report the failure (build or runtime) and move on to
    # the next kernel instead of aborting the whole run.
    import traceback
    traceback.print_exc()
    print('Hash Router: ERROR', flush=True)

torch.cuda.empty_cache()
|
|
|
|
# 2. Top-k Select
# Builds the topk_select CUDA extension and compares its (values, indices)
# output with torch.topk along the last dimension.
# NOTE(review): the element-wise comparison assumes the kernel returns its k
# results in the same order as torch.topk — confirm against the kernel.
print('\n=== Top-k Select ===', flush=True)
try:
    tk = load(
        name='tk2',
        sources=['dsv4/kernels/cuda/topk_select.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    k = 6
    # Track the outcome of every case so the summary line reflects reality
    # rather than printing "ALL PASS" unconditionally.
    topk_ok = True
    for N, E in [(1, 256), (4, 256), (64, 256), (64, 384), (128, 256), (512, 256)]:
        scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
        ov, oidx = tk.topk_select(scores, k)
        torch.cuda.synchronize()
        exp = scores.topk(k, dim=-1)
        ids_ok = (oidx == exp.indices).all().item()
        vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
        ok = ids_ok and vals_ok
        topk_ok = topk_ok and ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        # Drop references before the next case allocates.
        del scores, ov, oidx, exp
    print(f"Top-k Select: {'ALL PASS' if topk_ok else 'FAILED'}", flush=True)
except Exception:
    # Best-effort script: report the failure (build or runtime) and move on to
    # the next kernel instead of aborting the whole run.
    import traceback
    traceback.print_exc()
    print('Top-k Select: ERROR', flush=True)

torch.cuda.empty_cache()
|
|
|
|
# 3. Activation + Top-k
# Builds the activation_topk CUDA extension and checks it against a PyTorch
# oracle: activation = sqrt(softplus(logits)); experts are selected by top-k
# of (activation + bias); the returned weights are the un-biased activations
# of the selected experts, normalised to sum to 1 and multiplied by `scaling`.
print('\n=== Activation + Top-k ===', flush=True)
try:
    atk = load(
        name='atk2',
        sources=['dsv4/kernels/cuda/activation_topk.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    k = 6
    scaling = 2.5
    # Track the outcome of every case so the summary line reflects reality
    # rather than printing "ALL PASS" unconditionally.
    activation_topk_ok = True
    for N, E in [(1, 256), (4, 256), (64, 256), (64, 384)]:
        logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
        bias = torch.randn(E, dtype=torch.float32, device='cuda')
        # Output buffers handed to the kernel; compared against the oracle below.
        out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
        out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
        atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
        torch.cuda.synchronize()

        # Oracle. Softplus in its overflow-safe form:
        #   softplus(x) = log1p(exp(-|x|)) + max(x, 0)
        # NOTE(review): the previous oracle used exp(x) instead of exp(-|x|),
        # which adds max(x, 0) on top of an already complete softplus for
        # positive logits. Confirm the kernel's spec is sqrt(softplus(x)).
        sp = torch.log1p(torch.exp(-logits.abs())) + torch.clamp(logits, min=0)
        act = sp.sqrt()
        # Bias influences which experts are selected, not their weights.
        score = act + bias
        exp_topk = score.topk(k, dim=-1)
        exp_ids = exp_topk.indices
        exp_w = torch.gather(act, 1, exp_ids)
        exp_w = exp_w / exp_w.sum(dim=-1, keepdim=True) * scaling

        ids_ok = (out_ids == exp_ids).all().item()
        vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
        ok = ids_ok and vals_ok
        activation_topk_ok = activation_topk_ok and ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ids_ok:
            print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
        if not vals_ok:
            print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
        # Drop references before the next case allocates.
        del logits, bias, out_w, out_ids, sp, act, score, exp_topk, exp_ids, exp_w
    print(f"Activation+Topk: {'ALL PASS' if activation_topk_ok else 'FAILED'}", flush=True)
except Exception:
    # Best-effort script: report the failure (build or runtime) and continue
    # to the end of the run instead of aborting.
    import traceback
    traceback.print_exc()
    print('Activation+Topk: ERROR', flush=True)
|
|
|
|
# Closing banner for the run.
done_banner = '\n=== DONE ==='
print(done_banner, flush=True)
|