Deepseek-v3 Batch Invariant on 8xH100 (#26609)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Bram Wasti
2025-10-15 22:06:02 -07:00
committed by GitHub
parent 785d8b6410
commit 7d8975de84
21 changed files with 1567 additions and 102 deletions

View File

@@ -395,7 +395,6 @@ def mean_dim(
Tensor with mean values along specified dimension
"""
# Validate inputs
assert input.is_cuda, "Input must be a CUDA tensor"
assert -input.ndim <= dim < input.ndim, (
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
)
@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
return matmul_persistent(a, b)
def matmul_batch_invariant(a, b, *, out=None):
# torch.matmul can handle various dimensions
# For 2D x 2D, it's the same as mm
if a.ndim == 2 and b.ndim == 2:
result = matmul_persistent(a, b)
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 3 and b.ndim == 3:
# Handle batched case like bmm
return bmm_batch_invariant(a, b, out=out)
else:
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
f"got shapes {a.shape} and {b.shape}"
)
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
# Process each batch separately with our persistent kernel
if a.ndim == 3 and b.ndim == 3:
results = []
for i in range(a.shape[0]):
results.append(matmul_persistent(a[i], b[i]))
result = torch.stack(results, dim=0)
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(
f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}"
)
def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias)
@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
return log_softmax(input, dim=dim)
def softmax_batch_invariant(input, dim, dtype=None):
# Compute softmax in a deterministic way
# First subtract max for numerical stability (standard practice)
input_max = torch.amax(input, dim=dim, keepdim=True)
input = input - input_max
exp_x = torch.exp(input)
sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
return exp_x / sum_exp_x
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
result = input.to(torch.float32)
if len(dim) == 0:
dim = [i for i in range(len(input.shape))]
# Sort dimensions to reduce from largest to smallest to handle shifting dims
# during iterative reduction.
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
return result
@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each block handles one row of the input tensor.
"""
row_idx = tl.program_id(0).to(tl.int64)
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Compute sum of squares in float32 to avoid overflow
sum_sq = tl.zeros([1], dtype=tl.float32)
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
# Convert to float32 for accumulation to prevent overflow
vals_f32 = vals.to(tl.float32)
sq_vals = vals_f32 * vals_f32
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
# Step 2: Compute RMS (root mean square) in float32
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms
# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
# Compute in float32 then convert back to input dtype
vals_f32 = vals.to(tl.float32)
weight_f32 = weight.to(tl.float32)
output_f32 = vals_f32 * inv_rms * weight_f32
output = output_f32.to(vals.dtype)
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel.
RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input.shape[-1] == weight.shape[0], (
f"Input last dimension ({input.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})"
)
# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()
n_rows, n_cols = input_2d.shape
output = torch.empty_like(input_2d)
BLOCK_SIZE = 1024
grid = (n_rows,)
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)
def rms_norm_batch_invariant(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.
This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
RMS normalized tensor
"""
return rms_norm(input, weight, eps=eps)
def linear_batch_invariant(input, weight, bias=None):
output = mm_batch_invariant(input, weight.t())
if bias is not None:
output = output + bias
return output
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
def is_batch_invariant_mode_enabled():
@@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_MODE:
return
@@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
_batch_invariant_LIB.impl(
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
)
_batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
def disable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_LIB is not None:
_batch_invariant_LIB._destroy()
if _original_torch_bmm is not None:
torch.bmm = _original_torch_bmm
_original_torch_bmm = None
_batch_invariant_MODE = False
_batch_invariant_LIB = None
@@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
return is_overridden
def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLEX_ATTENTION",
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHINFER_MLA",
]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning)
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NCCL determinism settings
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
os.environ["NCCL_COLLNET_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["NCCL_P2P_NET_DISABLE"] = "1"
os.environ["NCCL_MIN_NCHANNELS"] = "1"
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_PROTO"] = "Simple"
os.environ["NCCL_ALGO"] = "allreduce:tree"
os.environ["NCCL_NTHREADS"] = "1"
os.environ["NCCL_SOCKET_NTHREADS"] = "1"
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
override_envs_for_invariance()
enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False