[PERF] [Qwen3-next] Speed up gated RMSNorm (#26207)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Vadim Gimpelson
2025-10-12 12:27:50 +04:00
committed by GitHub
parent 4ca204055e
commit 82e64c7a20
2 changed files with 475 additions and 33 deletions

View File

@@ -13,6 +13,7 @@
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from functools import lru_cache
from typing import Optional
import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, next_power_of_2
from .utils import input_guard
@@ -76,55 +78,103 @@ def layer_norm_fwd_kernel(
stride_y_row,
stride_z_row,
M, # number of rows in X
N, # number of columns in X
N: tl.constexpr, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
ROWS_PER_BLOCK: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
group = tl.program_id(1)
X += row * stride_x_row + group * N
Y += row * stride_y_row + group * N
if HAS_Z:
Z += row * stride_z_row + group * N
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
if HAS_BIAS:
B += group * N
# Compute mean and variance
# Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]
rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
# Compute offsets for 2D tile
row_offsets = rows[:, None] * stride_x_row
col_offsets = cols[None, :] + group * N
# Base pointers
X_base = X + row_offsets + col_offsets
Y_base = Y + rows[:, None] * stride_y_row + col_offsets
# Create mask for valid rows and columns
row_mask = rows[:, None] < M
col_mask = cols[None, :] < N
mask = row_mask & col_mask
# Load input data with 2D tile
x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
x *= z * tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
mean = tl.sum(x, axis=1) / N # Shape: [ROWS_PER_BLOCK]
# Store mean for each row
mean_offsets = group * M + rows
mean_mask = rows < M
tl.store(Mean + mean_offsets, mean, mask=mean_mask)
# Broadcast mean back to 2D for subtraction
xbar = tl.where(mask, x - mean[:, None], 0.0)
var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK]
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
xbar = tl.where(mask, x, 0.0)
var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK]
mean = 0.0 # Placeholder for RMS norm
rstd = tl.rsqrt(var + eps) # Shape: [ROWS_PER_BLOCK]
# Store rstd for each row
rstd_offsets = group * M + rows
rstd_mask = rows < M
tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)
# Load weights and biases (broadcast across rows)
w_offsets = cols + group * N
w_mask = cols < N
w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w + b if HAS_BIAS else x_hat * w
b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
# Normalize and apply linear transformation
if not IS_RMS_NORM:
x_hat = (x - mean[:, None]) * rstd[:, None]
else:
x_hat = x * rstd[:, None]
y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask).to(tl.float32)
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
y *= z * tl.sigmoid(z)
# Write output
tl.store(Y + cols, y, mask=mask)
tl.store(Y_base, y, mask=mask)
@lru_cache(maxsize=128)  # explicit form of the bare decorator's default bound
def _get_sm_count(device: torch.device) -> int:
    """Return the multiprocessor (SM) count reported for ``device``.

    The result is memoized per ``device`` argument, so the CUDA device
    properties are only queried on the first call for each device.
    """
    return torch.cuda.get_device_properties(device).multi_processor_count
def calc_rows_per_block(M: int, device: torch.device) -> int:
    """Choose how many rows a single kernel program should process.

    The value is derived from the row count ``M`` and the SM count of
    ``device``: ``M`` is divided (via ``cdiv``) across twice the number of
    SMs, the quotient is passed through ``next_power_of_2``, and the result
    is capped at 4.

    Args:
        M: Number of rows to be normalized.
        device: Device whose SM count drives the heuristic.

    Returns:
        The rows-per-program value, never larger than 4.
    """
    # Target two programs per SM when splitting the rows.
    target_programs = 2 * _get_sm_count(device)
    rows_needed = cdiv(M, target_programs)
    # Upper bound of 4 rows per program.
    return min(next_power_of_2(rows_needed), 4)
def layer_norm_fwd(
@@ -171,7 +221,10 @@ def layer_norm_fwd(
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
# Calculate rows per block based on SM count
rows_per_block = calc_rows_per_block(M, x.device)
# Update grid to use rows_per_block
grid = (cdiv(M, rows_per_block), ngroups)
layer_norm_fwd_kernel[grid](
x,
out,
@@ -187,6 +240,7 @@ def layer_norm_fwd(
group_size,
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,