temp: restore EXACT old bridge.py from b685112
This commit is contained in:
@@ -7,75 +7,48 @@ the ScaledGroupedGemmKernel expects:
|
||||
- Scale factor assembly (padding + swizzle)
|
||||
- B tensor K-major stride conversion
|
||||
- Expert offset computation
|
||||
|
||||
CUDA-graph-compatible: no .item() calls, no torch.cuda.synchronize()
|
||||
in the forward path, no Python control flow on GPU data.
|
||||
|
||||
Compilation uses real tensors (not dummy shapes) because the CuTeDSL
|
||||
kernel's TMA descriptors are sized from compilation-time tensor shapes.
|
||||
This happens once during warmup, outside cudagraph capture.
|
||||
"""
|
||||
import math
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ScaledGroupedGemmKernel,
|
||||
utils,
|
||||
ceil_div,
|
||||
pad_and_swizzle_single,
|
||||
assemble_raw_scales_2d3d_2d_side,
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
cat_byte_reinterpretable_tensors,
|
||||
stack_byte_reinterpretable_tensors,
|
||||
)
|
||||
|
||||
# ── Constants ──────────────────────────────────────────────────────────

SF_VEC_SIZE = 16  # NVFP4 block size: number of elements sharing one block scale

# Magnitudes representable by an E2M1 (FP4) value, in ascending order.
# The quantizers in this file use a magnitude's position in this list as the
# low three bits of the packed nibble (bit 3 carries the sign).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
|
||||
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
_NVFP4_STEP_LUT_CACHE = {}
_NVFP4_STEP_LUT_LOCK = threading.Lock()


def _get_step_to_idx_lut(device):
    """Return the cached half-step -> E2M1 index lookup table for ``device``.

    Callers index this table with ``round(|x| * 2)`` clamped to ``[0, 12]``,
    so entry ``i`` is the E2M1 magnitude index chosen for a magnitude of
    ``i / 2``.

    The table is created once per device and reused afterwards, which avoids
    CPU→CUDA copies during cudagraph capture.

    Args:
        device: anything accepted by ``torch.device`` (a ``torch.device``,
            a string such as ``"cuda:0"``, ...).

    Returns:
        (13,) int8 tensor living on ``device``.
    """
    # Normalise the key so that "cpu" and torch.device("cpu") share one entry.
    # Note that an index-less device ("cuda") and an indexed one ("cuda:0")
    # still compare unequal and therefore get separate tables.
    key = torch.device(device)
    with _NVFP4_STEP_LUT_LOCK:
        lut = _NVFP4_STEP_LUT_CACHE.get(key)
        if lut is None:
            lut = torch.as_tensor(
                [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
                dtype=torch.int8,
                device=key,
            )
            _NVFP4_STEP_LUT_CACHE[key] = lut
    return lut
|
||||
def ceil_div(a, b):
    """Return ``a / b`` rounded up, using integer arithmetic only.

    The ``(a + b - 1) // b`` form is a ceiling division only for positive
    ``b``; callers in this file pass block sizes, which are positive.
    """
    return (a + b - 1) // b
|
||||
|
||||
|
||||
def round_up(a, b):
    """Round ``a`` up to the nearest multiple of ``b`` (``b`` positive)."""
    # Ceiling division followed by a multiply; written inline so the helper
    # has no dependency on other definitions.
    return ((a + b - 1) // b) * b
|
||||
|
||||
|
||||
# ── Quantization ───────────────────────────────────────────────────────
|
||||
# ── Quantization ──────────────────────────────────────────────────────
|
||||
|
||||
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 tensor to NVFP4.
|
||||
|
||||
NOTE: This function is NOT cudagraph-safe because it uses .max()
|
||||
which forces a CPU-GPU sync. It should only be called during
|
||||
weight preparation (offline), NOT during the forward pass.
|
||||
|
||||
For activation quantization during forward, use
|
||||
quantize_activation_nvfp4() instead (cudagraph-safe, fixed global scale).
|
||||
|
||||
Args:
|
||||
x_bf16: (..., D) BF16 tensor
|
||||
|
||||
Returns:
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2
|
||||
x_sf: (..., D//16) float8_e4m3fn
|
||||
x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4
|
||||
x_sf: (..., D//16) float8_e4m3fn — block scales
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
x_f32 = x_bf16.float()
|
||||
@@ -94,15 +67,15 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Nearest E2M1 — memory-efficient clamp approach
|
||||
# Nearest E2M1
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device)
|
||||
signs = torch.sign(x_scaled)
|
||||
abs_scaled = x_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(x_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
abs_scaled = x_scaled.abs().unsqueeze(-1)
|
||||
distances = (abs_scaled - magnitudes).abs()
|
||||
indices = distances.argmin(dim=-1)
|
||||
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
even = nibbles[..., ::2]
|
||||
@@ -119,70 +92,12 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
|
||||
return x_fp4, block_scale, global_scale
|
||||
|
||||
|
||||
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
    """Quantize a BF16 activation tensor to NVFP4 (cudagraph-safe).

    Unlike quantize_to_nvfp4(), this takes a pre-computed ``global_scale``
    instead of deriving it with ``.max()``. The global scale should be
    computed once during warmup and cached by the caller.

    Every step is a tensor operation on ``x_bf16``'s device; shapes are
    derived from Python ints only, so no value is read back from the GPU.

    Args:
        x_bf16: (..., D) BF16 tensor.
        global_scale: float32 scalar the input is divided by before block
            scaling (pre-computed, NOT from ``.max()``).
        block_size: number of consecutive elements sharing one block scale.
            Must be even, because elements are packed two per byte.

    Returns:
        x_fp4: (..., ceil(D / 2)) float4_e2m1fn_x2 — two 4-bit codes per
            byte, even element in the low nibble, odd element in the high
            nibble. Equals (..., D // 2) for even ``D``.
        x_sf: (..., ceil(D / block_size)) float8_e4m3fn block scales.
    """
    x_norm = x_bf16.float() / global_scale

    lead_shape = list(x_bf16.shape[:-1])
    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size
    is_padded = padded_dim != last_dim

    # Zero-pad the last dim to a whole number of blocks; zeros do not affect
    # a block's abs-max and quantize to code 0.
    if is_padded:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))

    x_reshaped = x_norm.reshape(*lead_shape, n_blocks, block_size)

    # Per-block scale: map the block's abs-max onto 6.0, the largest E2M1
    # magnitude, and store it as FP8.
    block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the *rounded* FP8 scale so the codes match what a consumer
    # reconstructs; the clamp guards against a scale that rounded to zero.
    block_sf_expanded = block_scale.float().unsqueeze(-1)
    x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
    signs = torch.sign(x_scaled)
    abs_scaled = x_scaled.abs().clamp(max=6.0)

    # Snap to half-integer steps (0 .. 12) and translate each step to an
    # E2M1 magnitude index through the cached lookup table.
    half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(x_bf16.device)
    indices = step_to_idx[half_steps.long()]

    # Bit 3 of the nibble is the sign bit.
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = (odd << 4) | even

    # Flatten the (n_blocks, block_size // 2) tail at its true (padded) width.
    # Reshaping straight to last_dim // 2, as before, fails whenever padding
    # was applied because the element counts no longer match.
    packed = packed.reshape(*lead_shape, padded_dim // 2)
    if is_padded:
        # Drop bytes that consist purely of padding. For odd D the final byte
        # keeps the last real element in its low nibble.
        packed = packed[..., : (last_dim + 1) // 2].contiguous()
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    block_scale = block_scale.reshape(lead_shape + [n_blocks])

    return x_fp4, block_scale
|
||||
|
||||
|
||||
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
|
||||
"""Quantize BF16 weight matrix to NVFP4.
|
||||
|
||||
The weight is (K, N) where K is the input dim (packed dimension).
|
||||
Block scales are computed along K (dim 0).
|
||||
|
||||
NOTE: NOT cudagraph-safe — uses .max() for global scale.
|
||||
|
||||
Args:
|
||||
w_bf16: (K, N) BF16 weight matrix
|
||||
|
||||
@@ -203,67 +118,27 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
|
||||
|
||||
w_reshaped = w_norm.reshape(k_blocks, block_size, N)
|
||||
w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8)
|
||||
block_scale = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
|
||||
block_sf_expanded = block_scale.float().unsqueeze(1)
|
||||
w_scaled = w_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
w_block_sf = w_sf.float().unsqueeze(1)
|
||||
w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
|
||||
|
||||
# Nearest E2M1 — memory-efficient clamp approach
|
||||
magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=w_bf16.device)
|
||||
signs = torch.sign(w_scaled)
|
||||
abs_scaled = w_scaled.abs().clamp(max=6.0)
|
||||
|
||||
half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8)
|
||||
step_to_idx = _get_step_to_idx_lut(w_bf16.device)
|
||||
indices = step_to_idx[half_steps.long()]
|
||||
|
||||
abs_scaled = w_scaled.abs().unsqueeze(-1)
|
||||
distances = (abs_scaled - magnitudes).abs()
|
||||
indices = distances.argmin(dim=-1)
|
||||
nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
|
||||
|
||||
even = nibbles[:, ::2, :]
|
||||
odd = nibbles[:, 1::2, :]
|
||||
packed = (odd << 4) | even
|
||||
|
||||
w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2)
|
||||
return w_fp4, block_scale, global_scale
|
||||
return w_fp4, w_sf, global_scale
|
||||
|
||||
|
||||
# ── Tensor Layout Conversion ───────────────────────────────────────────
|
||||
|
||||
def make_b_k_major(b_tensor):
    """Convert B tensor from N-major to K-major layout.

    For a contiguous input of shape (E, K, N) the strides are (K*N, N, 1),
    i.e. N is the contiguous dimension (this is what torch.stack produces).
    The returned tensor has the same shape and values but strides
    (K*N, 1, K), i.e. K is the contiguous dimension.

    Double-permute trick: swap the last two dims, materialise that order
    with .contiguous(), then swap back so the logical shape is unchanged.

    Args:
        b_tensor: 3-D tensor of shape (E, K, N).

    Returns:
        Tensor of shape (E, K, N) whose memory is K-contiguous.
    """
    n_by_k = b_tensor.permute(0, 2, 1).contiguous()
    return n_by_k.permute(0, 2, 1)
|
||||
|
||||
|
||||
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative token offsets for grouped GEMM.

    The result starts with 0 and has one more entry than
    ``tokens_per_expert``: entry ``i`` is the number of tokens that belong
    to experts ``0 .. i-1``.

    Args:
        tokens_per_expert: iterable of int, number of tokens per expert.
        num_experts: int, total number of experts. Accepted for interface
            compatibility; the output length is taken from
            ``tokens_per_expert`` and this value is not consulted.
        device: torch device the offsets tensor is created on.

    Returns:
        (len(tokens_per_expert) + 1,) int32 tensor of cumulative offsets.
    """
    running_total = 0
    offsets = [running_total]
    for count in tokens_per_expert:
        running_total += count
        offsets.append(running_total)
    return torch.tensor(offsets, dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
# ── Scale Assembly ─────────────────────────────────────────────────────
|
||||
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
assemble_raw_scales_2d3d_2d_side,
|
||||
assemble_raw_scales_2d3d_3d_side,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
|
||||
# ── Scale Factor Assembly ─────────────────────────────────────────────
|
||||
|
||||
def assemble_scales_2d_side(raw_scales):
|
||||
"""Assemble activation scale factors for the 2Dx3D scenario.
|
||||
@@ -282,33 +157,58 @@ def assemble_scales_3d_side(raw_scales):
|
||||
|
||||
Args:
|
||||
raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert
|
||||
NOTE: These will be transposed to (N, K_sf) before swizzling,
|
||||
since the kernel expects N as the non-K dimension.
|
||||
|
||||
Returns:
|
||||
Assembled and swizzled scale tensor
|
||||
"""
|
||||
# The 3D side expects (N, K_sf) so transpose if needed
|
||||
transposed = []
|
||||
for s in raw_scales:
|
||||
if s.shape[0] < s.shape[1]:
|
||||
# Likely (K_sf, N) — transpose to (N, K_sf)
|
||||
transposed.append(s.permute(1, 0).contiguous())
|
||||
else:
|
||||
transposed.append(s)
|
||||
# Kernel expects (N, K_sf) — transpose before swizzling
|
||||
transposed = [sf.T.contiguous() for sf in raw_scales]
|
||||
return assemble_raw_scales_2d3d_3d_side(transposed)
|
||||
|
||||
|
||||
# ── Tensor Layout Conversion ──────────────────────────────────────────
|
||||
|
||||
def make_b_k_major(b_tensor):
    """Convert B tensor from N-major to K-major layout.

    With E experts and per-expert shape (K, N), a contiguous input has
    strides (K*N, N, 1) — N is contiguous, as produced by torch.stack.
    The output keeps the shape and values but has strides (K*N, 1, K) —
    K is contiguous.

    Args:
        b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major

    Returns:
        Same shape, K-major strides
    """
    # Lay the data out as (E, N, K) in memory, then present it as (E, K, N)
    # again through a stride-only permute.
    k_contiguous = b_tensor.permute(0, 2, 1).contiguous()
    return k_contiguous.permute(0, 2, 1)
|
||||
|
||||
|
||||
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative token offsets for the grouped GEMM.

    Entry ``e`` is the inclusive running total of tokens for experts
    ``0 .. e`` (there is no leading zero).

    Args:
        tokens_per_expert: list of int, one per expert
        num_experts: number of entries to produce. If it exceeds
            ``len(tokens_per_expert)``, the trailing entries repeat the
            grand total; if it is smaller, the extra counts are ignored.
        device: torch device the offsets tensor is created on

    Returns:
        offs: (num_experts,) int32 — cumulative sum
    """
    # Single running total instead of re-summing a prefix slice per expert.
    available = len(tokens_per_expert)
    running_total = 0
    cumulative = []
    for expert in range(num_experts):
        if expert < available:
            running_total += tokens_per_expert[expert]
        cumulative.append(running_total)
    return torch.tensor(cumulative, dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
# ── Kernel Launch ─────────────────────────────────────────────────────
|
||||
|
||||
# Cache compiled kernels by (num_experts, device, K, N)
# NOTE(review): no read or write of this dict is visible in this part of the
# file — confirm it is still used before relying on the key layout above.
_compiled_kernel_cache = {}
|
||||
|
||||
|
||||
def run_nvfp4_grouped_gemm(
|
||||
mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2
|
||||
mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
|
||||
scale_a, # assembled 2D side (padded + swizzled)
|
||||
scale_b, # assembled 3D side (padded + swizzled)
|
||||
expert_offsets, # (experts+1,) int32 cumulative token offsets
|
||||
expert_offsets, # (experts,) int32 cumulative token offsets
|
||||
global_scale_a=None, # (experts,) float32
|
||||
global_scale_b=None, # (experts,) float32
|
||||
mma_tiler_mn=(128, 128),
|
||||
@@ -317,19 +217,12 @@ def run_nvfp4_grouped_gemm(
|
||||
"""Run the CuTeDSL NVFP4 scaled grouped GEMM.
|
||||
|
||||
2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)
|
||||
|
||||
Compiles with real tensors each call. The CuTeDSL kernel's TMA
|
||||
descriptors are bound to the compilation-time tensor addresses,
|
||||
so caching across different tensor allocations produces wrong results.
|
||||
|
||||
The forward call (after compilation) is cudagraph-safe.
|
||||
"""
|
||||
num_experts = mat_b.shape[0]
|
||||
n_dim = mat_b.shape[2]
|
||||
n_dim = mat_b.shape[2] # packed N (in float4 elements)
|
||||
tokens_sum = mat_a.shape[0]
|
||||
device = mat_a.device
|
||||
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device)
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
|
||||
kernel = ScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
@@ -354,7 +247,7 @@ def run_nvfp4_grouped_gemm(
|
||||
offs_c = to_cute(expert_offsets)
|
||||
|
||||
workspace_size = kernel.get_workspace_size(num_experts)
|
||||
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
|
||||
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
|
||||
ws_c = to_cute(workspace)
|
||||
|
||||
gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
|
||||
@@ -376,5 +269,6 @@ def run_nvfp4_grouped_gemm(
|
||||
stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user