2024-12-26 18:33:30 -05:00
|
|
|
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
|
|
|
|
from typing import List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import triton
|
|
|
|
|
import triton.language as tl
|
|
|
|
|
|
2025-01-17 12:12:26 -05:00
|
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
2024-12-26 18:33:30 -05:00
|
|
|
|
|
|
|
|
def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Linear layer with block-quantized fp8 weights and dynamically
    quantized fp8 activations.

    The activations are quantized per token group of ``block_size[1]``
    elements, multiplied with the block-quantized weight, and the optional
    bias is added.

    Args:
        input: Activations; the last dimension is the reduction dimension K.
        weight: Block-quantized fp8 weight of shape ``(N, K)``.
        block_size: ``[block_n, block_k]`` used to quantize ``weight``.
            ``block_k`` is also used as the activation group size.
        weight_scale: Per-block scales of ``weight``.
        input_scale: Must be ``None``; static activation scales are not
            supported, activations are always quantized dynamically.
        bias: Optional bias added to the output.

    Returns:
        Tensor of shape ``(*input.shape[:-1], N)`` in ``input.dtype``.
    """
    assert input_scale is None, (
        "static input scales are not supported; activations are quantized "
        "dynamically")
    # View input as a 2D matrix for the fp8 methods. `reshape` +
    # `contiguous` (instead of `view`) also accepts non-contiguous inputs;
    # the per-token-group quantizer requires a contiguous buffer. Both are
    # no-ops for an already-contiguous input.
    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
    output = w8a8_block_fp8_matmul(q_input,
                                   weight,
                                   x_scale,
                                   weight_scale,
                                   block_size,
                                   output_dtype=input.dtype)
    if bias is not None:
        output = output + bias
    # `bias` may have a different dtype than `input` (type promotion above),
    # so cast back to keep the output in the activation dtype.
    return output.to(dtype=input.dtype).view(*output_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def input_to_float8(
|
2025-01-17 12:12:26 -05:00
|
|
|
x: torch.Tensor,
|
|
|
|
|
dtype: Optional[torch.dtype] = None
|
2024-12-26 18:33:30 -05:00
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
"""This function quantizes input values to float8 values "
|
|
|
|
|
"with tensor-wise quantization."""
|
2025-01-17 12:12:26 -05:00
|
|
|
if dtype is None:
|
|
|
|
|
dtype = (torch.float8_e4m3fnuz
|
|
|
|
|
if current_platform.is_rocm() else torch.float8_e4m3fn)
|
2024-12-26 18:33:30 -05:00
|
|
|
finfo = torch.finfo(dtype)
|
|
|
|
|
min_val, max_val = x.aminmax()
|
|
|
|
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
|
|
|
|
scale = finfo.max / amax
|
|
|
|
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
|
|
|
|
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert block-wise quantization to tensor-wise quantization.

    Every ``[block_n, block_k]`` tile of ``x_q_block`` is dequantized with its
    own scale from ``x_s``. Edge tiles are smaller when the shape is not a
    multiple of the block size. The float32 result is then re-quantized with
    one tensor-wise scale via ``input_to_float8``. Note only float8 is
    supported for now.

    Args:
        x_q_block: Block-wise quantized 2D tensor of shape ``(n, k)``.
        x_s: Block-wise scales of shape
            ``(ceil(n / block_n), ceil(k / block_k))``.
        block_size: ``[block_n, block_k]``.

    Returns:
        ``(x_q_tensor, scale)``: the tensor-wise quantized tensor (same dtype
        as ``x_q_block``) and its tensor-wise dequantization scale.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # Expand each per-block scale over the tile it covers. Trimming to
    # (n, k) handles ragged edge tiles. Dequantize with one vectorized
    # in-place multiply instead of a Python loop issuing one op per tile.
    scale_full = (x_s.to(device=x_q_block.device, dtype=torch.float32)
                  .repeat_interleave(block_n, dim=0)[:n]
                  .repeat_interleave(block_k, dim=1)[:, :k])
    x_dq_block = x_q_block.to(torch.float32)
    x_dq_block.mul_(scale_full)
    del scale_full

    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    return x_q_tensor, scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input: element distance between consecutive groups. The same
    # offset is applied to y_q_ptr, so the output is assumed to share y's
    # layout.
    y_stride,
    # Columns of input: number of elements in each group (must be <= BLOCK)
    N,
    # Lower bound on each group's absolute maximum, to avoid dividing by zero
    eps,
    # Information for float8: representable range of the output dtype
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.
    This function converts the tensor values into float8 values.

    Each program instance handles one group g (its program id):
    - It loads N values starting at `y_ptr + g * y_stride`.
    - It computes `scale = max(max(|y|), eps) / fp8_max`.
    - It stores `clamp(y / scale, fp8_min, fp8_max)`, cast to y_q's element
      type, at the same offset from `y_q_ptr`.
    - It stores `scale` to `y_s_ptr[g]`.
    """
    # Map the program id to the group of Y (and its slots in Y_q / Y_s) it
    # should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    # BLOCK may be larger than N; lanes past N are masked off.
    mask = cols < N

    # Masked lanes read 0.0, which cannot increase the absolute maximum.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Perform per-token-group quantization on an input tensor `x`.

    The last dimension of `x` is split into groups of `group_size`
    consecutive elements. Each group is scaled by its own absolute maximum
    so it fits the signed float8 range. The quantized tensor is returned
    together with the per-group scaling factors.

    Args:
        x: The input tensor. It must be contiguous, and its last dimension
            must be divisible by `group_size`.
        group_size: The number of consecutive elements that share one scale.
        eps: Lower bound on each group's absolute maximum, to avoid dividing
            by zero.
        dtype: The float8 dtype of the quantized output. Defaults to
            `torch.float8_e4m3fnuz` on ROCm and `torch.float8_e4m3fn`
            elsewhere.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
        - The quantized tensor, with the same shape as `x`.
        - The float32 scaling factors, of shape
          `(*x.shape[:-1], x.shape[-1] // group_size)`.
        Each group dequantizes as `x_q.float() * scale`.
    """
    if dtype is None:
        dtype = (torch.float8_e4m3fnuz
                 if current_platform.is_rocm() else torch.float8_e4m3fn)
    assert (x.shape[-1] % group_size == 0), (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}")
    assert x.is_contiguous(), "`x` must be contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    # `x` is contiguous and its last dimension is a multiple of `group_size`,
    # so memory holds M back-to-back groups of `group_size` elements each.
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size, ),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Launch one program per group. Consecutive groups are `group_size`
    # elements apart, which is passed as the kernel's stride.
    _per_token_group_quant_fp8[(M, )](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.

    Each program instance computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    `C`, reading `A` as an (M, K) matrix and `B` as a (K, N) matrix through
    the given strides. For every BLOCK_SIZE_K slice of K:
    - The partial product is rescaled by one scale per row of `A`, taken
      from `As`.
    - It is also rescaled by one scale per column of `B`, taken from `Bs`
      and selected by `column // group_n`.
    - The result is accumulated in float32.

    NOTE: only one scale column (`k_start // group_k`) is loaded per K slice.
    Results are therefore only correct if each BLOCK_SIZE_K slice lies
    entirely inside one group_k quantization group.
    """

    # Re-order program ids so that GROUP_SIZE_M consecutive tile-rows are
    # launched together (the grouped ordering from the Triton matmul
    # tutorial).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices of this tile. The modulo wraps out-of-range indices
    # back into bounds so the loads stay valid. The masked store at the end
    # discards those rows/columns.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One A-scale pointer per row. The B-scale pointers select the
    # quantization block (along N) that each output column falls into.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the K tail so it contributes nothing to the dot product.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)

        # Scale group (along K) that this K slice belongs to.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        # Dequantize the partial product with the per-row and per-column
        # scales before accumulating.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the float32 accumulator to C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Recompute offsets without wrap-around and mask out rows/columns past
    # M/N, so the wrapped duplicates from the loads are never written.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Matrix multiplication with block-wise quantization.

    Computes ``C = (A * As) @ (B * Bs).T``, where ``A`` is quantized per
    token group along K and ``B`` is quantized per ``[block_n, block_k]``
    block. The output is returned in the specified ``output_dtype``.

    Args:
        A: The input tensor, e.g., activation, of shape ``(..., K)``. Must be
            contiguous.
        B: The input tensor, e.g., weight, of shape ``(N, K)``. Must be
            contiguous.
        As: The per-token-group quantization scale for ``A``, of shape
            ``(..., ceil(K / block_k))``.
        Bs: The per-block quantization scale for ``B``, of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128], interpreted as ``[block_n, block_k]``.
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, of shape ``(..., N)``.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    # TODO: BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
    # Tiling constraints:
    # - block_k must be divisible by BLOCK_SIZE_K. The kernel loads a single
    #   scale column per K slice, so a slice must never straddle a
    #   quantization-group boundary.
    # - BLOCK_SIZE_N and BLOCK_SIZE_M have no requirements. The kernel looks
    #   up A scales per row and B scales per column (`column // block_n`).
    BLOCK_SIZE_M = 128
    if M < BLOCK_SIZE_M:
        # Shrink the M tile for small batches. The floor of 16 is presumably
        # the smallest tile `tl.dot` accepts; confirm before changing it.
        BLOCK_SIZE_M = max(triton.next_power_of_2(M), 16)
    BLOCK_SIZE_K = block_k
    assert block_k % BLOCK_SIZE_K == 0
    BLOCK_SIZE_N = block_n

    def grid(META):
        # 1D launch; the kernel maps each program id to an (M, N) tile.
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
                triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    # B (and Bs) are stored as (N, K) but the kernel indexes them as (K, N),
    # hence the swapped stride order for those two tensors.
    _w8a8_block_fp8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
    )

    return C
|