"""Unit test: wo_a NVFP4 grouped linear + inverse RoPE.

Tests the CuTeDSL NVFP4 grouped GEMM that replaces DeepGEMM's fp8_einsum
for the wo_a (o-projection first half) in DeepSeek V4 attention.

Also tests inverse_rope_bf16 against a synthetic reference.

Usage (B200): python3 tests/test_wo_a.py

Requires: CuTeDSL, CUDA, Blackwell GPU
"""

import os
import sys

import torch
import torch.nn.functional as F

# Add repo root to path (must run before the project-local imports below)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from cutedsl.inverse_rope import inverse_rope_bf16  # noqa: E402
from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA  # noqa: E402

# Device every tensor in this file is allocated on.
DEVICE = "cuda:0"

# DeepSeek V4 Pro dimensions
N_LOCAL_GROUPS = 8
HEADS_PER_GROUP = 16  # 128 heads / 8 groups
HEAD_DIM = 512
NOPE_DIM = 448  # leading channels left untouched by the rotation in the tests
ROPE_DIM = 64  # trailing channels that are rotated (NOPE_DIM + ROPE_DIM == HEAD_DIM)
O_LORA_RANK = 1536
GROUP_IN = HEADS_PER_GROUP * HEAD_DIM  # 8192
NUM_TOKENS = 4


def test_inverse_rope():
    """Round-trip check for ``inverse_rope_bf16``.

    Applies an interleaved (even/odd pair) forward rotation to the last
    ``ROPE_DIM`` channels of a random BF16 tensor, passes the result through
    ``inverse_rope_bf16``, and compares the output with the untouched input.

    Returns:
        float: cosine similarity between the flattened original tensor and
        the flattened round-tripped tensor (both cast to float32).
    """
    print("\n=== Test: inverse_rope_bf16 ===")

    torch.manual_seed(42)
    num_tokens = NUM_TOKENS
    num_heads = N_LOCAL_GROUPS * HEADS_PER_GROUP
    max_pos = 128

    # Build cos_sin_cache (same format as vLLM: cos||sin concatenated)
    rope_dim = ROPE_DIM
    half_rope = rope_dim // 2
    base = 10000.0
    exponents = (
        torch.arange(0, half_rope, dtype=torch.float32, device=DEVICE) / half_rope
    )
    inv_freq = 1.0 / (base ** exponents)

    pos = torch.arange(max_pos, dtype=torch.float32, device=DEVICE)
    freqs = torch.outer(pos, inv_freq)  # (max_pos, half_rope)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, rope_dim)

    # Random attention output
    o = torch.randn(
        num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE
    ) * 2.0
    positions = torch.randint(
        0, max_pos, (num_tokens,), dtype=torch.int64, device=DEVICE
    )

    # Apply RoPE (forward), then inverse
    # Forward RoPE (GPT-J interleaved):
    o_rope = o[:, :, NOPE_DIM:].clone()
    # unsqueeze(1) inserts a head axis so the per-token tables broadcast
    # across all heads.
    cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
    sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]
    rope_even = o_even * cos_all - o_odd * sin_all
    rope_odd = o_even * sin_all + o_odd * cos_all
    o_fwd = o.clone()
    # Basic slicing returns views, so these chained assignments write into o_fwd.
    o_fwd[:, :, NOPE_DIM:][:, :, 0::2] = rope_even
    o_fwd[:, :, NOPE_DIM:][:, :, 1::2] = rope_odd

    # Apply inverse RoPE
    o_inv = inverse_rope_bf16(o_fwd, positions, cos_sin_cache, NOPE_DIM, ROPE_DIM)

    # Compare with original
    cos = F.cosine_similarity(
        o.flatten().unsqueeze(0).float(),
        o_inv.flatten().unsqueeze(0).float(),
    ).item()
    mse = (o.float() - o_inv.float()).pow(2).mean().item()
    # ">=" matches the comparison main() applies to the same 0.999 threshold.
    status = "✅" if cos >= 0.999 else "❌"
    print(f" inverse_rope → original: cosine={cos:.6f} MSE={mse:.6e} {status}")
    return cos


def test_wo_a_grouped_linear():
    """Compare the ``CuTeDSLNvfp4WoA`` runner with a BF16 grouped matmul.

    Builds a random BF16 activation and a random BF16 grouped weight,
    computes a per-group ``activation @ weight`` reference, runs the same
    inputs through the runner, and prints cosine similarity, MSE and the
    maxima of both outputs.

    Returns:
        float: cosine similarity between the flattened reference output and
        the flattened runner output (both cast to float32).
    """
    print("\n=== Test: wo_a NVFP4 Grouped Linear ===")

    torch.manual_seed(42)
    num_tokens = NUM_TOKENS

    # Random attention output (after inverse RoPE)
    o = torch.randn(
        num_tokens,
        N_LOCAL_GROUPS * HEADS_PER_GROUP,
        HEAD_DIM,
        dtype=torch.bfloat16,
        device=DEVICE,
    ) * 2.0

    # Random wo_a weight (BF16, grouped format)
    # In vLLM, wo_a is ColumnParallelLinear with is_bmm=True
    # Weight shape: (n_local_groups, heads_per_group * head_dim, o_lora_rank)
    wo_a_weight = torch.randn(
        N_LOCAL_GROUPS,
        GROUP_IN,
        O_LORA_RANK,
        dtype=torch.bfloat16,
        device=DEVICE,
    ) * 0.1

    # BF16 reference: grouped matmul
    o_grouped = o.reshape(num_tokens, N_LOCAL_GROUPS, GROUP_IN)
    z_ref = torch.empty(
        num_tokens,
        N_LOCAL_GROUPS,
        O_LORA_RANK,
        dtype=torch.bfloat16,
        device=DEVICE,
    )
    for g in range(N_LOCAL_GROUPS):
        # (tokens, GROUP_IN) × (GROUP_IN, O_LORA_RANK) → (tokens, O_LORA_RANK)
        z_ref[:, g, :] = o_grouped[:, g, :] @ wo_a_weight[g]

    # CuTeDSL NVFP4 runner
    runner = CuTeDSLNvfp4WoA(
        n_local_groups=N_LOCAL_GROUPS,
        heads_per_group=HEADS_PER_GROUP,
        head_dim=HEAD_DIM,
        o_lora_rank=O_LORA_RANK,
        max_num_tokens=8192,
        device=DEVICE,
    )
    runner.set_bf16_weight(wo_a_weight)
    runner.finalize_weights()

    # Warmup + compute activation global scale
    # NOTE(review): _ensure_initialized is a private runner method; its exact
    # effect is not visible from this file — confirm against the runner.
    runner._ensure_initialized()
    runner.compute_activation_global_scale(o)

    # Run
    with torch.no_grad():
        z_out = runner.run(o)

    # Compare
    cos = F.cosine_similarity(
        z_ref.flatten().unsqueeze(0).float(),
        z_out.flatten().unsqueeze(0).float(),
    ).item()
    mse = (z_ref.float() - z_out.float()).pow(2).mean().item()
    status = "✅" if cos >= 0.98 else "❌"
    print(f" wo_a grouped linear: cosine={cos:.6f} MSE={mse:.6e} {status}")
    print(f" z_ref amax={z_ref.amax():.4f} z_out amax={z_out.amax():.4f}")

    return cos


def main():
    """Run both tests, print a per-test summary, and report overall status.

    Each test returns a cosine similarity that is compared against its
    threshold (0.999 for ``inverse_rope``, 0.98 for ``wo_a_grouped_linear``).

    Returns:
        int: 0 when every cosine similarity meets its threshold, 1 otherwise
        (usable as a process exit status).
    """
    torch.cuda.set_device(0)
    print("=== wo_a NVFP4 Grouped Linear + Inverse RoPE Tests ===")

    cos_rope = test_inverse_rope()
    cos_woa = test_wo_a_grouped_linear()

    print("\n=== SUMMARY ===")
    results = {"inverse_rope": cos_rope, "wo_a_grouped_linear": cos_woa}
    all_pass = True
    for name, cos in results.items():
        threshold = 0.999 if name == "inverse_rope" else 0.98
        # One comparison drives both the icon and the overall verdict. A NaN
        # cosine fails ">=" and therefore counts as a failure; a separate
        # "cos < threshold" check would be False for NaN and let it slip
        # through as a pass.
        passed = cos >= threshold
        if not passed:
            all_pass = False
        status = "✅" if passed else "❌"
        print(f" {name}: cosine={cos:.6f} {status}")

    if all_pass:
        print("\n✅ ALL PASS")
    else:
        print("\n❌ SOME FAILED")

    return 0 if all_pass else 1


if __name__ == "__main__":
    # Forward main()'s result as the process exit status so a failing run is
    # visible to shells and CI. sys.exit(None) exits with status 0, so this
    # is also safe when main() returns nothing.
    sys.exit(main())