# File: nvfp4-megamoe-kernel/tests/archive/test_wo_a.py

"""Unit test: wo_a NVFP4 grouped linear + inverse RoPE.
Tests the CuTeDSL NVFP4 grouped GEMM that replaces DeepGEMM's fp8_einsum
for wo_a (the first half of the output projection) in DeepSeek V4 attention.
Also tests inverse_rope_bf16 by round-tripping a synthetic forward RoPE.
Usage (B200): python3 tests/archive/test_wo_a.py
Requires: CuTeDSL, CUDA, Blackwell GPU
"""
import sys
import os
import torch
import torch.nn.functional as F
def _find_repo_root(start, marker="dsv4"):
    """Return the nearest directory at or above *start* that contains *marker*.

    This file lives under tests/archive/, so a fixed two-level dirname() lands
    on tests/ instead of the repo root and the ``dsv4`` imports below fail.
    Walking upward is robust to the file being moved again.

    Args:
        start: Directory to start searching from (inclusive).
        marker: Name of a subdirectory that identifies the repo root.

    Returns:
        The first ancestor containing *marker*, or ``dirname(start)`` (the
        previous hard-coded guess) if no ancestor contains it.
    """
    candidate = start
    while True:
        if os.path.isdir(os.path.join(candidate, marker)):
            return candidate
        parent = os.path.dirname(candidate)
        if parent == candidate:  # reached the filesystem root without a match
            return os.path.dirname(start)
        candidate = parent


# Add repo root to path so `dsv4` is importable when run as a script.
sys.path.insert(0, _find_repo_root(os.path.dirname(os.path.abspath(__file__))))
from dsv4.ops.rope import inverse_rope_bf16
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
DEVICE = "cuda:0"

# DeepSeek V4 Pro dimensions
N_LOCAL_GROUPS = 8
HEADS_PER_GROUP = 16  # 128 heads split evenly across 8 groups
HEAD_DIM = 512
ROPE_DIM = 64  # trailing rotary slice of each head
NOPE_DIM = HEAD_DIM - ROPE_DIM  # 448 leading non-rotary dims
O_LORA_RANK = 1536
GROUP_IN = HEAD_DIM * HEADS_PER_GROUP  # 8192 input features per group
NUM_TOKENS = 4
def test_inverse_rope():
    """Check that inverse_rope_bf16 undoes a forward GPT-J (interleaved) RoPE.

    Builds a cos||sin cache, rotates the trailing ROPE_DIM slice of a random
    (NUM_TOKENS, N_LOCAL_GROUPS * HEADS_PER_GROUP, HEAD_DIM) bf16 tensor with a
    reference forward RoPE, applies ``inverse_rope_bf16`` to the result, and
    compares against the original tensor.

    Returns:
        The smaller of two cosine similarities against the original: over the
        whole tensor, and over the rotary slice alone. The rotary slice is only
        ROPE_DIM / HEAD_DIM of the elements, so a whole-tensor cosine by itself
        is dominated by the non-rotary dims. This function only reports; the
        caller decides pass/fail from the returned value.
    """
    print("\n=== Test: inverse_rope_bf16 ===")
    torch.manual_seed(42)
    num_tokens = NUM_TOKENS
    num_heads = N_LOCAL_GROUPS * HEADS_PER_GROUP
    max_pos = 128
    half_rope = ROPE_DIM // 2
    base = 10000.0
    pass_threshold = 0.999

    # Build cos_sin_cache (same format as vLLM: cos||sin concatenated).
    inv_freq = 1.0 / (
        base ** (torch.arange(0, half_rope, dtype=torch.float32, device=DEVICE) / half_rope)
    )
    pos = torch.arange(max_pos, dtype=torch.float32, device=DEVICE)
    freqs = torch.outer(pos, inv_freq)  # (max_pos, half_rope)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, ROPE_DIM)

    # Random attention output and per-token positions.
    o = torch.randn(num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
    positions = torch.randint(0, max_pos, (num_tokens,), dtype=torch.int64, device=DEVICE)

    # Reference forward RoPE (GPT-J interleaved: dims 2i and 2i+1 of the rotary
    # slice rotate together). Computed in fp32 and rounded to bf16 once, so the
    # reference does not add bf16-rounded cos/sin error to the round trip.
    cos_t = cos_sin_cache[positions, :half_rope].unsqueeze(1)  # (tokens, 1, half_rope)
    sin_t = cos_sin_cache[positions, half_rope:].unsqueeze(1)
    x = o[:, :, NOPE_DIM:].float()
    x_even, x_odd = x[:, :, 0::2], x[:, :, 1::2]
    o_fwd = o.clone()
    o_fwd[:, :, NOPE_DIM::2] = (x_even * cos_t - x_odd * sin_t).to(o.dtype)
    o_fwd[:, :, NOPE_DIM + 1::2] = (x_even * sin_t + x_odd * cos_t).to(o.dtype)

    # Apply inverse RoPE. The comparison assumes the output has the same
    # element order as `o` (the flattened comparison already relied on this).
    o_inv = inverse_rope_bf16(o_fwd, positions, cos_sin_cache, NOPE_DIM, ROPE_DIM)
    o_inv = o_inv.reshape(o.shape)

    def _cosine(a, b):
        # Cosine similarity of two tensors treated as flat vectors, in fp32.
        return F.cosine_similarity(a.float().flatten(), b.float().flatten(), dim=0).item()

    cos_full = _cosine(o, o_inv)
    cos_rope = _cosine(o[:, :, NOPE_DIM:], o_inv[:, :, NOPE_DIM:])
    mse = (o.float() - o_inv.float()).pow(2).mean().item()
    cos_sim = min(cos_full, cos_rope)
    # `>=` is False for NaN, so a NaN result reports FAIL.
    status = "PASS" if cos_sim >= pass_threshold else "FAIL"
    print(
        f" inverse_rope → original: cosine={cos_full:.6f} "
        f"rope-slice cosine={cos_rope:.6f} MSE={mse:.6e} {status}"
    )
    return cos_sim
def test_wo_a_grouped_linear():
    """Compare the CuTeDSL NVFP4 wo_a grouped linear against a reference matmul.

    Feeds a random (NUM_TOKENS, N_LOCAL_GROUPS * HEADS_PER_GROUP, HEAD_DIM) bf16
    activation and a random (N_LOCAL_GROUPS, GROUP_IN, O_LORA_RANK) bf16 weight
    through ``Nvfp4GroupedLinear``. The output is compared with a per-group
    matmul computed in fp32.

    Returns:
        Cosine similarity between the reference and kernel outputs over all
        elements. This function only reports; the caller decides pass/fail.

    Raises:
        RuntimeError: if the kernel output does not have
            NUM_TOKENS * N_LOCAL_GROUPS * O_LORA_RANK elements.
    """
    print("\n=== Test: wo_a NVFP4 Grouped Linear ===")
    torch.manual_seed(42)
    num_tokens = NUM_TOKENS
    pass_threshold = 0.98

    # Random attention output (after inverse RoPE).
    o = torch.randn(
        num_tokens, N_LOCAL_GROUPS * HEADS_PER_GROUP, HEAD_DIM,
        dtype=torch.bfloat16, device=DEVICE,
    ) * 2.0
    # Random wo_a weight (BF16, grouped format).
    # In vLLM, wo_a is ColumnParallelLinear with is_bmm=True.
    # Weight shape: (n_local_groups, heads_per_group * head_dim, o_lora_rank)
    wo_a_weight = torch.randn(
        N_LOCAL_GROUPS, GROUP_IN, O_LORA_RANK,
        dtype=torch.bfloat16, device=DEVICE,
    ) * 0.1

    # Reference: z[t, g, :] = o_grouped[t, g, :] @ W[g]. One batched einsum in
    # fp32 replaces the per-group Python loop and avoids bf16 output rounding.
    o_grouped = o.reshape(num_tokens, N_LOCAL_GROUPS, GROUP_IN)
    z_ref = torch.einsum("tgk,gkn->tgn", o_grouped.float(), wo_a_weight.float())

    # CuTeDSL NVFP4 runner.
    runner = Nvfp4GroupedLinear(
        n_local_groups=N_LOCAL_GROUPS,
        heads_per_group=HEADS_PER_GROUP,
        head_dim=HEAD_DIM,
        o_lora_rank=O_LORA_RANK,
        max_num_tokens=8192,
        device=DEVICE,
    )
    runner.set_bf16_weight(wo_a_weight)
    runner.finalize_weights()
    # Initialize the runner, then compute the activation global scale from `o`.
    runner._ensure_initialized()
    runner.compute_activation_global_scale(o)

    with torch.no_grad():
        z_out = runner.run(o)

    # Compare as flat vectors (assumes z_out uses z_ref's token-major element
    # order). Flattening both sides avoids accidental shape broadcasting in the
    # MSE. Check the size explicitly to get a clear error on a mismatch.
    if z_out.numel() != z_ref.numel():
        raise RuntimeError(
            f"wo_a output has {z_out.numel()} elements, expected {z_ref.numel()} "
            f"(num_tokens x n_local_groups x o_lora_rank)"
        )
    ref_flat = z_ref.flatten()
    out_flat = z_out.flatten().float()
    cos_sim = F.cosine_similarity(ref_flat, out_flat, dim=0).item()
    mse = (ref_flat - out_flat).pow(2).mean().item()
    # `>=` is False for NaN, so a NaN result reports FAIL.
    status = "PASS" if cos_sim >= pass_threshold else "FAIL"
    print(f" wo_a grouped linear: cosine={cos_sim:.6f} MSE={mse:.6e} {status}")
    print(f" z_ref amax={z_ref.amax().item():.4f} z_out amax={z_out.amax().item():.4f}")
    return cos_sim
def main():
    """Run both tests, print a summary, and return a process exit code.

    Each test returns a cosine similarity. A test passes only when that value
    is ``>=`` its threshold, so a NaN result counts as a failure.

    Returns:
        0 if every test meets its threshold, otherwise 1.
    """
    torch.cuda.set_device(0)
    print("=== wo_a NVFP4 Grouped Linear + Inverse RoPE Tests ===")

    # (name, test function, minimum cosine similarity required to pass)
    cases = (
        ("inverse_rope", test_inverse_rope, 0.999),
        ("wo_a_grouped_linear", test_wo_a_grouped_linear, 0.98),
    )
    results = [(name, run_test(), threshold) for name, run_test, threshold in cases]

    print("\n=== SUMMARY ===")
    all_pass = True
    for name, cos_sim, threshold in results:
        # Written as `>=` (not `not <`) so NaN is treated as a failure.
        passed = cos_sim >= threshold
        all_pass = all_pass and passed
        status = "PASS" if passed else "FAIL"
        print(f" {name}: cosine={cos_sim:.6f} (threshold {threshold}) {status}")

    if all_pass:
        print("\n✅ ALL PASS")
        return 0
    print("\n❌ SOME FAILED")
    return 1
if __name__ == "__main__":
    # Propagate main()'s return value as the exit status so CI can detect
    # failures (sys.exit(None) exits with status 0).
    sys.exit(main())