Kernel-override Determinism [1/n] (#25603)

Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
Bram Wasti
2025-09-26 19:59:09 -04:00
committed by GitHub
parent 4778b42660
commit dc48ba0c75
8 changed files with 890 additions and 4 deletions

View File

@@ -0,0 +1,561 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Union
import torch
import triton
import triton.language as tl
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
args: dict[str, Any]) -> dict[str, Any]:
ret = {}
m, n, k = args["M"], args["N"], args["K"]
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
if "tiles_per_update" in args:
ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, "
f"tiles_per_update={args['tiles_per_update']:02}]")
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
return ret
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
a_ptr,
b_ptr,
c_ptr, #
bias_ptr,
M,
N,
K, #
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
NUM_SMS: tl.constexpr, #
A_LARGE: tl.constexpr,
B_LARGE: tl.constexpr,
C_LARGE: tl.constexpr,
HAS_BIAS: tl.constexpr,
):
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
GROUP_SIZE_M, NUM_SMS)
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
if A_LARGE:
offs_am = offs_am.to(tl.int64)
if B_LARGE:
offs_bn = offs_bn.to(tl.int64)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
BLOCK_SIZE_N)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
if A_LARGE or B_LARGE:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
tl.int64)
else:
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
a = tl.load(a_ptrs,
mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
other=0.0)
accumulator = tl.dot(a, b, accumulator)
tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
GROUP_SIZE_M, NUM_SMS)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
offs_cm = offs_cm.to(tl.int64)
offs_cn = offs_cn.to(tl.int64)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
if HAS_BIAS:
bias_ptrs = bias_ptr + offs_cn
bias = tl.load(bias_ptrs, mask=offs_cn < N,
other=0.0).to(tl.float32)
accumulator += bias
if c_ptr.dtype.element_ty == tl.float8e4nv:
c = accumulator.to(tl.float8e4nv)
else:
c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)
def matmul_persistent(a: torch.Tensor,
b: torch.Tensor,
bias: Union[torch.Tensor, None] = None):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
assert bias is None or bias.dim() == 1, (
"Currently assuming bias is 1D, let Horace know if you run into this")
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
M, K = a.shape
K, N = b.shape
dtype = a.dtype
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=dtype)
# 1D launch kernel where each block gets its own program.
def grid(META):
return (min(
NUM_SMS,
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"])), )
configs = {
torch.bfloat16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
torch.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
torch.float32: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 3,
"num_warps": 8,
},
}
# print(a.device, b.device, c.device)
matmul_kernel_persistent[grid](
a,
b,
c, #
bias,
M,
N,
K, #
a.stride(0),
a.stride(1), #
b.stride(0),
b.stride(1), #
c.stride(0),
c.stride(1), #
NUM_SMS=NUM_SMS, #
A_LARGE=a.numel() > 2**31,
B_LARGE=b.numel() > 2**31,
C_LARGE=c.numel() > 2**31,
HAS_BIAS=bias is not None,
**configs[dtype],
)
return c
@triton.jit
def _log_softmax_kernel(
input_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute log_softmax along the last dimension of a 2D tensor.
Each block handles one row of the input tensor.
"""
# Get the row index for this block
row_idx = tl.program_id(0).to(tl.int64)
# Compute base pointers for input and output rows
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Find maximum value in the row for numerical stability
max_val = -float("inf")
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
# Load values
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
# Update maximum
max_val = tl.max(tl.maximum(vals, max_val))
# Step 2: Compute sum of exp(x - max_val)
sum_exp = 0.0
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
# Load values
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
# Compute exp(x - max_val) and accumulate
exp_vals = tl.exp(vals - max_val)
sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))
# Compute log(sum_exp)
log_sum_exp = tl.log(sum_exp)
# Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
# Load values
vals = tl.load(row_start_ptr + col_idx, mask=mask)
# Compute log_softmax
output = vals - max_val - log_sum_exp
# Store results
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Compute log_softmax using Triton kernel.
Args:
input: Input tensor
dim: Dimension along which to compute log_softmax
(only -1 or last dim supported)
>> Stashed changes
Returns:
Tensor with log_softmax applied along the specified dimension
"""
if dim != -1 and dim != input.ndim - 1:
raise ValueError("This implementation only supports log_softmax along "
"the last dimension")
# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
n_rows, n_cols = input_2d.shape
# Allocate output tensor
output = torch.empty_like(input_2d)
# Choose block size based on the number of columns
BLOCK_SIZE = 1024
# Launch kernel with one block per row
grid = (n_rows, )
_log_softmax_kernel[grid](
input_2d,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
)
# Reshape output back to original shape
return output.reshape(original_shape)
@triton.jit
def mean_kernel(
input_ptr,
output_ptr,
input_stride0,
input_stride1,
input_stride2,
output_stride0,
output_stride1,
M, # size before reduction dim
N, # size of reduction dim
K, # size after reduction dim
BLOCK_SIZE: tl.constexpr,
):
"""
Kernel for computing mean along a single dimension.
Input is viewed as (M, N, K) where N is the dimension being reduced.
"""
# Program ID gives us which output element we're computing
pid = tl.program_id(0)
# Compute output indices
m_idx = pid // K
k_idx = pid % K
# Bounds check
if m_idx >= M or k_idx >= K:
return
# Accumulate sum across reduction dimension
acc = 0.0
for n_start in range(0, N, BLOCK_SIZE):
n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
mask = n_offsets < N
# Calculate input indices
input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \
+ k_idx * input_stride2
# Load and accumulate
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
acc += tl.sum(vals)
# Compute mean and store
mean_val = acc / N
output_idx = m_idx * output_stride0 + k_idx * output_stride1
tl.store(output_ptr + output_idx, mean_val)
def mean_dim(input: torch.Tensor,
dim: int,
keepdim: bool = False,
dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
"""
Triton implementation of torch.mean with single dimension reduction.
Args:
input: Input tensor
dim: Single dimension along which to compute mean
keepdim: Whether to keep the reduced dimension
dtype: Output dtype. If None, uses input dtype
(or float32 for integer inputs)
Returns:
Tensor with mean values along specified dimension
"""
# Validate inputs
assert input.is_cuda, "Input must be a CUDA tensor"
assert -input.ndim <= dim < input.ndim, (
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")
# Handle negative dim
if dim < 0:
dim = dim + input.ndim
# Handle dtype
if dtype is None:
if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
dtype = torch.float32
else:
dtype = input.dtype
# Convert input to appropriate dtype if needed
if input.dtype != dtype:
input = input.to(dtype)
# Get input shape and strides
shape = list(input.shape)
# Calculate dimensions for kernel
M = 1
for i in range(dim):
M *= shape[i]
N = shape[dim]
K = 1
for i in range(dim + 1, len(shape)):
K *= shape[i]
# Reshape input to 3D view (M, N, K)
input_3d = input.reshape(M, N, K)
# Create output shape
if keepdim:
output_shape = shape.copy()
output_shape[dim] = 1
else:
output_shape = shape[:dim] + shape[dim + 1:]
# Create output tensor
output = torch.empty(output_shape, dtype=dtype, device=input.device)
# Reshape output for kernel
if keepdim:
output_2d = output.reshape(M, 1, K).squeeze(1)
else:
output_2d = output.reshape(M, K)
# Launch kernel
grid = (M * K, )
BLOCK_SIZE = 1024
mean_kernel[grid](
input_3d,
output_2d,
input_3d.stride(0),
input_3d.stride(1),
input_3d.stride(2),
output_2d.stride(0),
output_2d.stride(1) if output_2d.ndim > 1 else 0,
M,
N,
K,
BLOCK_SIZE,
)
return output
def mm_batch_invariant(a, b):
return matmul_persistent(a, b)
def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias)
def _log_softmax_batch_invariant(input, dim, _half_to_float):
assert not _half_to_float, "not implemented"
return log_softmax(input, dim=dim)
def mean_batch_invariant(input,
dim,
keepdim=False,
dtype: Union[torch.dtype, None] = None):
assert dtype is None or dtype == torch.float32, \
f"unsupported dtype: {dtype}"
result = input.to(torch.float32)
# Sort dimensions to reduce from largest to smallest to handle shifting dims
# during iterative reduction.
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
# Iteratively apply a deterministic mean.
for d in sorted_dims:
result = mean_dim(result, dim=d, keepdim=True)
if not keepdim:
# Squeeze the reduced dimensions.
for d in sorted_dims:
result = result.squeeze(d)
return result
_batch_invariant_MODE = False
_batch_invariant_LIB = None
def is_batch_invariant_mode_enabled():
return _batch_invariant_MODE
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB
if _batch_invariant_MODE:
return
_batch_invariant_MODE = True
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::_log_softmax",
_log_softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
def disable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB
if _batch_invariant_LIB is not None:
_batch_invariant_LIB._destroy()
_batch_invariant_MODE = False
_batch_invariant_LIB = None
@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
global _batch_invariant_MODE, _batch_invariant_LIB
old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
if enabled:
enable_batch_invariant_mode()
else:
disable_batch_invariant_mode()
yield
if _batch_invariant_LIB is not None:
_batch_invariant_LIB._destroy()
_batch_invariant_MODE, _batch_invariant_LIB = old_data
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16)
def vllm_kernel_override_batch_invariant():
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
is_overridden = False
val = os.getenv(env_key, "0")
try:
is_overridden = int(val) != 0
except ValueError:
is_overridden = False
return is_overridden
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
enable_batch_invariant_mode()