Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,14 +11,17 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
|
||||
args: dict[str, Any]) -> dict[str, Any]:
|
||||
def _matmul_launch_metadata(
|
||||
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
ret = {}
|
||||
m, n, k = args["M"], args["N"], args["K"]
|
||||
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||
if "tiles_per_update" in args:
|
||||
ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||
f"tiles_per_update={args['tiles_per_update']:02}]")
|
||||
ret["name"] = (
|
||||
f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||
f"tiles_per_update={args['tiles_per_update']:02}]"
|
||||
)
|
||||
if "c_ptr" in args:
|
||||
bytes_per_elem = args["c_ptr"].element_size()
|
||||
else:
|
||||
@@ -75,8 +78,9 @@ def matmul_kernel_persistent(
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
||||
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
|
||||
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
|
||||
GROUP_SIZE_M, NUM_SMS)
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
start_m = pid_m * BLOCK_SIZE_M
|
||||
start_n = pid_n * BLOCK_SIZE_N
|
||||
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
|
||||
@@ -87,46 +91,44 @@ def matmul_kernel_persistent(
|
||||
offs_bn = offs_bn.to(tl.int64)
|
||||
offs_am = tl.where(offs_am < M, offs_am, 0)
|
||||
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
|
||||
BLOCK_SIZE_M)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
|
||||
BLOCK_SIZE_N)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for ki in range(k_tiles):
|
||||
if A_LARGE or B_LARGE:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
|
||||
tl.int64)
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
|
||||
else:
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_bn[None, :] * stride_bn)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||
)
|
||||
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
a = tl.load(
|
||||
a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
|
||||
)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
|
||||
GROUP_SIZE_M, NUM_SMS)
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if C_LARGE:
|
||||
offs_cm = offs_cm.to(tl.int64)
|
||||
offs_cn = offs_cn.to(tl.int64)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
if HAS_BIAS:
|
||||
bias_ptrs = bias_ptr + offs_cn
|
||||
bias = tl.load(bias_ptrs, mask=offs_cn < N,
|
||||
other=0.0).to(tl.float32)
|
||||
bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
|
||||
accumulator += bias
|
||||
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||
c = accumulator.to(tl.float8e4nv)
|
||||
@@ -135,14 +137,15 @@ def matmul_kernel_persistent(
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
def matmul_persistent(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
bias: Union[torch.Tensor, None] = None):
|
||||
def matmul_persistent(
|
||||
a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None
|
||||
):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.dtype == b.dtype, "Incompatible dtypes"
|
||||
assert bias is None or bias.dim() == 1, (
|
||||
"Currently assuming bias is 1D, let Horace know if you run into this")
|
||||
"Currently assuming bias is 1D, let Horace know if you run into this"
|
||||
)
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
@@ -152,10 +155,13 @@ def matmul_persistent(a: torch.Tensor,
|
||||
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
def grid(META):
|
||||
return (min(
|
||||
NUM_SMS,
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"])), )
|
||||
return (
|
||||
min(
|
||||
NUM_SMS,
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
),
|
||||
)
|
||||
|
||||
configs = {
|
||||
torch.bfloat16: {
|
||||
@@ -284,8 +290,9 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
Tensor with log_softmax applied along the specified dimension
|
||||
"""
|
||||
if dim != -1 and dim != input.ndim - 1:
|
||||
raise ValueError("This implementation only supports log_softmax along "
|
||||
"the last dimension")
|
||||
raise ValueError(
|
||||
"This implementation only supports log_softmax along the last dimension"
|
||||
)
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input.shape
|
||||
@@ -301,7 +308,7 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
# Launch kernel with one block per row
|
||||
grid = (n_rows, )
|
||||
grid = (n_rows,)
|
||||
_log_softmax_kernel[grid](
|
||||
input_2d,
|
||||
output,
|
||||
@@ -350,8 +357,9 @@ def mean_kernel(
|
||||
mask = n_offsets < N
|
||||
|
||||
# Calculate input indices
|
||||
input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \
|
||||
+ k_idx * input_stride2
|
||||
input_idx = (
|
||||
m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
|
||||
)
|
||||
|
||||
# Load and accumulate
|
||||
vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
|
||||
@@ -363,10 +371,12 @@ def mean_kernel(
|
||||
tl.store(output_ptr + output_idx, mean_val)
|
||||
|
||||
|
||||
def mean_dim(input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
|
||||
def mean_dim(
|
||||
input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: Union[torch.dtype, None] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
|
||||
@@ -383,7 +393,8 @@ def mean_dim(input: torch.Tensor,
|
||||
# Validate inputs
|
||||
assert input.is_cuda, "Input must be a CUDA tensor"
|
||||
assert -input.ndim <= dim < input.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")
|
||||
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
|
||||
)
|
||||
|
||||
# Handle negative dim
|
||||
if dim < 0:
|
||||
@@ -422,7 +433,7 @@ def mean_dim(input: torch.Tensor,
|
||||
output_shape = shape.copy()
|
||||
output_shape[dim] = 1
|
||||
else:
|
||||
output_shape = shape[:dim] + shape[dim + 1:]
|
||||
output_shape = shape[:dim] + shape[dim + 1 :]
|
||||
|
||||
# Create output tensor
|
||||
output = torch.empty(output_shape, dtype=dtype, device=input.device)
|
||||
@@ -434,7 +445,7 @@ def mean_dim(input: torch.Tensor,
|
||||
output_2d = output.reshape(M, K)
|
||||
|
||||
# Launch kernel
|
||||
grid = (M * K, )
|
||||
grid = (M * K,)
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
mean_kernel[grid](
|
||||
@@ -467,12 +478,10 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
def mean_batch_invariant(input,
|
||||
dim,
|
||||
keepdim=False,
|
||||
dtype: Union[torch.dtype, None] = None):
|
||||
assert dtype is None or dtype == torch.float32, \
|
||||
f"unsupported dtype: {dtype}"
|
||||
def mean_batch_invariant(
|
||||
input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None
|
||||
):
|
||||
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||
|
||||
result = input.to(torch.float32)
|
||||
|
||||
@@ -509,8 +518,9 @@ def enable_batch_invariant_mode():
|
||||
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::_log_softmax",
|
||||
_log_softmax_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl(
|
||||
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
||||
)
|
||||
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user