Files
nvfp4-megamoe-kernel/tests/test_wo_a.py

172 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Unit test: wo_a NVFP4 grouped linear + inverse RoPE.
Tests the CuTeDSL NVFP4 grouped GEMM that replaces DeepGEMM's fp8_einsum
for the wo_a (o-projection first half) in DeepSeek V4 attention.
Also tests inverse_rope_bf16 against a synthetic reference.
Usage (B200): python3 tests/test_wo_a.py
Requires: CuTeDSL, CUDA, Blackwell GPU
"""
import sys
import os
import torch
import torch.nn.functional as F
# Add repo root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cutedsl.inverse_rope import inverse_rope_bf16
from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA
# Device string passed to the tensor allocations and to the runner below.
DEVICE = "cuda:0"
# DeepSeek V4 Pro dimensions
# Number of weight groups; the reference matmul loops over this many groups.
N_LOCAL_GROUPS = 8
HEADS_PER_GROUP = 16 # 128 heads / 8 groups
# Per-head width; equals NOPE_DIM + ROPE_DIM (448 + 64).
HEAD_DIM = 512
# Leading channels of each head; the tests rotate only channels [NOPE_DIM:].
NOPE_DIM = 448
# Trailing channels of each head, used as (even, odd) rotation pairs.
ROPE_DIM = 64
# Last dimension of the wo_a weight, i.e. the output width per group.
O_LORA_RANK = 1536
# Flattened input width of one group (all heads of the group concatenated).
GROUP_IN = HEADS_PER_GROUP * HEAD_DIM # 8192
# Number of tokens (rows) in the random test inputs.
NUM_TOKENS = 4
def test_inverse_rope():
    """Round-trip check for ``inverse_rope_bf16``.

    Builds a ``cos || sin`` cache, applies a forward interleaved (GPT-J style)
    rotation to the trailing ``ROPE_DIM`` channels of a random BF16 tensor,
    runs ``inverse_rope_bf16`` on the rotated tensor and compares the result
    with the untouched original.

    Returns:
        float: cosine similarity between the flattened original and the
        flattened round-tripped tensor. The check passes at ``>= 0.999``,
        the same comparison the summary in ``main`` applies.
    """
    print("\n=== Test: inverse_rope_bf16 ===")
    torch.manual_seed(42)
    num_tokens = NUM_TOKENS
    num_heads = N_LOCAL_GROUPS * HEADS_PER_GROUP
    max_pos = 128

    # Build cos_sin_cache (same format as vLLM: cos||sin concatenated).
    half_rope = ROPE_DIM // 2
    base = 10000.0
    exponents = torch.arange(0, half_rope, dtype=torch.float32, device=DEVICE) / half_rope
    inv_freq = 1.0 / (base ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32, device=DEVICE)
    freqs = torch.outer(pos, inv_freq)  # (max_pos, half_rope)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, rope_dim)

    # Random attention output and one random position per token.
    o = torch.randn(num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0
    positions = torch.randint(0, max_pos, (num_tokens,), dtype=torch.int64, device=DEVICE)

    # Forward RoPE (GPT-J interleaved): channels (2i, 2i+1) of the rotary
    # slice form one pair. unsqueeze(1) broadcasts cos/sin over the heads.
    o_rope = o[:, :, NOPE_DIM:].clone()
    cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
    sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]
    rope_even = o_even * cos_all - o_odd * sin_all
    rope_odd = o_even * sin_all + o_odd * cos_all

    # Write the rotated pairs back into a copy; the first NOPE_DIM channels
    # stay exactly as in `o`.
    o_fwd = o.clone()
    o_fwd[:, :, NOPE_DIM::2] = rope_even
    o_fwd[:, :, NOPE_DIM + 1::2] = rope_odd

    # Apply inverse RoPE
    o_inv = inverse_rope_bf16(o_fwd, positions, cos_sin_cache, NOPE_DIM, ROPE_DIM)

    # Compare with original
    cos = F.cosine_similarity(
        o.flatten().unsqueeze(0).float(),
        o_inv.flatten().unsqueeze(0).float(),
    ).item()
    mse = (o.float() - o_inv.float()).pow(2).mean().item()
    # A NaN cosine compares False here and is therefore reported as a failure.
    status = "✅ PASS" if cos >= 0.999 else "❌ FAIL"
    print(f" inverse_rope → original: cosine={cos:.6f} MSE={mse:.6e} {status}")
    return cos
def test_wo_a_grouped_linear():
    """Compare the CuTeDSL NVFP4 wo_a grouped linear with a BF16 reference.

    A random BF16 activation and a random BF16 grouped weight are multiplied
    group by group with plain ``@`` to form the reference. The same weight is
    handed to ``CuTeDSLNvfp4WoA`` and the runner output is compared with the
    reference by cosine similarity and MSE.

    Returns:
        float: cosine similarity between the flattened reference and the
        flattened runner output. The check passes at ``>= 0.98``.
    """
    print("\n=== Test: wo_a NVFP4 Grouped Linear ===")
    torch.manual_seed(42)
    num_tokens = NUM_TOKENS

    # Random attention output (after inverse RoPE)
    o = torch.randn(
        num_tokens, N_LOCAL_GROUPS * HEADS_PER_GROUP, HEAD_DIM,
        dtype=torch.bfloat16, device=DEVICE,
    ) * 2.0

    # Random wo_a weight (BF16, grouped format)
    # In vLLM, wo_a is ColumnParallelLinear with is_bmm=True
    # Weight shape: (n_local_groups, heads_per_group * head_dim, o_lora_rank)
    wo_a_weight = torch.randn(
        N_LOCAL_GROUPS, GROUP_IN, O_LORA_RANK,
        dtype=torch.bfloat16, device=DEVICE,
    ) * 0.1

    # BF16 reference: grouped matmul. The heads of each group are flattened
    # into one GROUP_IN-wide row per token.
    o_grouped = o.reshape(num_tokens, N_LOCAL_GROUPS, GROUP_IN)
    z_ref = torch.empty(
        num_tokens, N_LOCAL_GROUPS, O_LORA_RANK,
        dtype=torch.bfloat16, device=DEVICE,
    )
    for g in range(N_LOCAL_GROUPS):
        # (tokens, GROUP_IN) × (GROUP_IN, O_LORA_RANK) → (tokens, O_LORA_RANK)
        z_ref[:, g, :] = o_grouped[:, g, :] @ wo_a_weight[g]

    # CuTeDSL NVFP4 runner
    runner = CuTeDSLNvfp4WoA(
        n_local_groups=N_LOCAL_GROUPS,
        heads_per_group=HEADS_PER_GROUP,
        head_dim=HEAD_DIM,
        o_lora_rank=O_LORA_RANK,
        max_num_tokens=8192,
        device=DEVICE,
    )
    runner.set_bf16_weight(wo_a_weight)
    runner.finalize_weights()

    # Warmup + compute activation global scale
    runner._ensure_initialized()
    runner.compute_activation_global_scale(o)

    # Run
    with torch.no_grad():
        z_out = runner.run(o)

    # Compare
    cos = F.cosine_similarity(
        z_ref.flatten().unsqueeze(0).float(),
        z_out.flatten().unsqueeze(0).float(),
    ).item()
    mse = (z_ref.float() - z_out.float()).pow(2).mean().item()
    # A NaN cosine compares False here and is therefore reported as a failure.
    status = "✅ PASS" if cos >= 0.98 else "❌ FAIL"
    print(f" wo_a grouped linear: cosine={cos:.6f} MSE={mse:.6e} {status}")
    print(f" z_ref amax={z_ref.amax():.4f} z_out amax={z_out.amax():.4f}")
    return cos
def main():
    """Run both checks, print a summary and report the overall outcome.

    Returns:
        int: ``0`` when every cosine similarity reaches its threshold,
        ``1`` otherwise. The value is used as the process exit status when
        the file is run as a script.
    """
    torch.cuda.set_device(0)
    print("=== wo_a NVFP4 Grouped Linear + Inverse RoPE Tests ===")
    cos_rope = test_inverse_rope()
    cos_woa = test_wo_a_grouped_linear()

    print("\n=== SUMMARY ===")
    results = {"inverse_rope": cos_rope, "wo_a_grouped_linear": cos_woa}
    all_pass = True
    for name, cos in results.items():
        threshold = 0.999 if name == "inverse_rope" else 0.98
        # Decide on `cos >= threshold` rather than `cos < threshold`: every
        # comparison with NaN is False, so a NaN cosine must count as a
        # failure instead of slipping through as a pass.
        passed = cos >= threshold
        if not passed:
            all_pass = False
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f" {name}: cosine={cos:.6f} {status}")

    if all_pass:
        print("\n✅ ALL PASS")
    else:
        print("\n❌ SOME FAILED")
    return 0 if all_pass else 1


if __name__ == "__main__":
    # Propagate failures to the shell / CI through the exit status.
    sys.exit(main())