fucken aye

This commit is contained in:
2026-05-12 22:33:55 +00:00
parent fa825c16b9
commit 2bdda36bb7
2 changed files with 50 additions and 133 deletions

View File

@@ -63,26 +63,64 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
# Convert FP32 → E4M3 manually
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
e4m3_mant = scale_mant >> 20
round_bit = (scale_mant >> 19) & 1
e4m3_mant = e4m3_mant + round_bit
# Convert FP32 → E4M3 manually (with subnormal support)
# FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120
e4m3_exp_raw = scale_exp - 120 # can be negative → subnormal
# Normal path: exp >= 1, just truncate mantissa top 3 bits
# RNE rounding: need guard (bit 19), sticky (OR of bits 18:0), and LSB of result
normal_mant = scale_mant >> 20
guard_bit = (scale_mant >> 19) & 1
sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0) # OR of bits [18:0]
result_lsb = normal_mant & 1
# RNE: round up if (guard=1 and sticky=1) or (guard=1 and sticky=0 and lsb=1)
round_up = guard_bit & (sticky_bit | result_lsb)
normal_mant = normal_mant + round_up
normal_exp = e4m3_exp_raw
# Subnormal path: exp_raw <= 0
# Insert implicit leading 1 and right-shift by (1 - exp_raw)
# E4M3 subnormal: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
# So we need: (1 + mant_fp32/2^23) * 2^(exp_raw - 7) = (shifted_mant/8) * 2^-6
# shifted_mant = (implicit_1 | mant_fp32) >> (1 - exp_raw - 1) then take top 3 bits
shift = 1 - e4m3_exp_raw # positive when subnormal
mant_with_leading = (0x800000 | scale_mant) # insert implicit 1
# Right-shift to get into the 3-bit E4M3 mantissa window
# We want bits [shift+19 : shift+23) of mant_with_leading for 3 mantissa bits + 1 round bit
subnormal_mant = (mant_with_leading >> (shift.to(tl.int32) + 20)) & 0x7
sub_guard_bit = (mant_with_leading >> (shift.to(tl.int32) + 19)) & 1
# Sticky: OR of all bits below the guard bit in the shifted result
# shift ≤ 8 in practice (amax floor = 1e-4 → scale ≈ 2^-15 → exp_raw ≈ -7), so mask ≤ 2^27
sub_sticky_mask = (1 << (shift.to(tl.int32) + 19)) - 1
sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
sub_result_lsb = subnormal_mant & 1
sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb)
subnormal_mant = subnormal_mant + sub_round_up
is_normal = e4m3_exp_raw >= 1
e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
e4m3_exp = tl.where(is_normal, normal_exp, 0) # exp=0 for subnormals
# Handle mantissa overflow after rounding
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
# Reconstruct dequantized scale for E2M1 quantization
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
# Reconstruct dequantized scale by decoding the STORED E4M3 bits.
# This guarantees the E2M1 quantization divides by exactly the value
# the CUDA kernel will multiply back — same bits, single decode, no
# possibility of encode/decode disagreement.
stored_exp = (scale_e4m3_bits >> 3) & 0xF
stored_mant = scale_e4m3_bits & 0x7
e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126)
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)
# ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]

View File

@@ -1,121 +0,0 @@
"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data.
Run inside the vllm container:
python3 /patches/test_nvfp4_mega_moe.py
"""
import torch
import torch.distributed as dist
import os
import sys
def test_nvfp4_mega_moe():
    """Smoke-test a single-rank launch of ``fp8_nvfp4_mega_moe`` on synthetic data.

    The test builds random packed-FP4 weights with E4M3 block scales,
    transforms them with ``transform_nvfp4_weights_for_mega_moe``, fills the
    symmetric buffer with random activations, constant scales and random
    routing data, then launches the kernel under a SIGALRM-based timeout.
    It only checks that the launch completes; no numerical reference is
    compared.

    Raises:
        TimeoutError: if the alarm fires before the kernel call returns.
        Exception: any error raised by the kernel call is logged and re-raised.
    """
    import signal

    # Use dimensions that satisfy all alignment requirements:
    # - hidden and intermediate_hidden must be multiples of 128 and 64
    # - block_m will be at least 32 (SMEM alignment: 32 * 64 = 2048 >= 1024)
    num_experts = 2
    num_tokens = 32  # must be multiple of alignment
    top_k = 2
    hidden = 512  # multiple of 128 and 64
    intermediate_hidden = 1024  # multiple of 128 and 64
    device = "cuda"
    timeout_s = 15
    torch.cuda.set_device(0)

    # Single-rank process group for SymmBuffer
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    if not dist.is_initialized():
        dist.init_process_group("nccl")
    group = dist.new_group()

    from deep_gemm.mega import (
        fp8_nvfp4_mega_moe,
        get_symm_buffer_for_nvfp4_mega_moe,
        transform_nvfp4_weights_for_mega_moe,
    )

    # --- Weights: random NVFP4 ---
    w13_weight = torch.randint(
        0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2),
        dtype=torch.uint8, device=device).view(torch.int8)
    w13_weight_scale = torch.randn(
        num_experts, 2 * intermediate_hidden, hidden // 16,
        device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
    # global scale = 1 for simplicity
    w13_weight_scale_2 = torch.ones(num_experts, device=device)
    w2_weight = torch.randint(
        0, 256, (num_experts, hidden, intermediate_hidden // 2),
        dtype=torch.uint8, device=device).view(torch.int8)
    w2_weight_scale = torch.randn(
        num_experts, hidden, intermediate_hidden // 16,
        device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
    w2_weight_scale_2 = torch.ones(num_experts, device=device)

    print("Transforming weights...")
    l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
        (w13_weight, w13_weight_scale),
        (w2_weight, w2_weight_scale),
        l1_weight_scale_2=w13_weight_scale_2,
        l2_weight_scale_2=w2_weight_scale_2,
    )
    for name, t in [("l1_w", l1_weights[0]), ("l1_w_sf", l1_weights[1]),
                    ("l2_w", l2_weights[0]), ("l2_w_sf", l2_weights[1])]:
        print(f"  {name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()} contig={t.is_contiguous()}")

    # --- Symm buffer ---
    print("Creating symm buffer...")
    symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
        group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)
    for name, t in [("x", symm_buffer.x), ("x_sf", symm_buffer.x_sf),
                    ("l1_acts", symm_buffer.l1_acts), ("l1_acts_sf", symm_buffer.l1_acts_sf),
                    ("l2_acts", symm_buffer.l2_acts), ("l2_acts_sf", symm_buffer.l2_acts_sf)]:
        print(f"  symm_{name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()}")

    # --- Stage inputs ---
    # The real staging kernel (BF16 hidden_states -> packed FP4 + UE4M3
    # scales) is not run here; the buffers are filled directly instead.
    print("Staging inputs...")
    hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device) * 0.5
    topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)

    # Write actual FP4 data to the activation buffer (random but valid packed E2M1)
    symm_buffer.x[:num_tokens].copy_(
        torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8))

    # x_sf shape is (tokens, hidden//64) as int32 - each int32 = 4 packed UE4M3 bytes.
    # 0x3C decodes to exponent field 7 (unbiased 0) and mantissa 4, i.e.
    # (1 + 4/8) * 2^0 = 1.5, so every block scale is the constant 1.5.
    symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C)

    # Write topk data with one slice copy per tensor instead of a Python loop
    # of .item() calls (each .item() forces a device synchronisation).
    # copy_ converts to the destination dtype.
    symm_buffer.topk_idx[:num_tokens, :top_k].copy_(topk_ids)
    symm_buffer.topk_weights[:num_tokens, :top_k].copy_(topk_weights)
    torch.cuda.synchronize()
    print("Buffer populated with random FP4 data")

    # --- Run kernel ---
    y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    print("Calling fp8_nvfp4_mega_moe...", flush=True)

    def handler(signum, frame):
        raise TimeoutError("Kernel timeout")

    # NOTE(review): Python runs signal handlers only once the interpreter
    # regains control, so a call blocked inside native code may not be
    # interrupted - confirm this timeout actually fires on a real GPU hang.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout_s)
    try:
        fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer)
        torch.cuda.synchronize()
        signal.alarm(0)
        print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
    except TimeoutError:
        print(f"TIMEOUT: kernel did not complete in {timeout_s}s (GPU hang?)")
        # Propagate so a hang is reported as a failure, not a silent pass.
        raise
    except Exception as e:
        print(f"FAILED: {e}")
        raise
    finally:
        # Always cancel any pending alarm and put the old handler back so
        # the timeout cannot fire later in unrelated code.
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
# Script entry point: run the smoke test when this file is executed directly
# rather than imported.
if __name__ == "__main__":
    test_nvfp4_mega_moe()