2025-04-11 11:54:08 -06:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-04-11 11:54:08 -06:00
# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
import functools
import json
import logging
import os
2025-05-13 12:17:23 +01:00
from typing import Any , Optional
2025-04-11 11:54:08 -06:00
import torch
from vllm . platforms import current_platform
2025-05-06 17:53:09 +08:00
from vllm . triton_utils import tl , triton
2025-04-11 11:54:08 -06:00
# Module-scoped logger; used for the kernel-config lookup messages.
logger = logging.getLogger(name=__name__)
def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Linear layer with block-wise int8 weights and dynamic int8 activations.

    The activation is quantized on the fly with per-token-group scales
    (group size ``block_size[1]``) and multiplied against the pre-quantized
    ``weight`` / ``weight_scale`` pair by the block int8 matmul kernel.

    Args:
        input: Activation tensor. All leading dimensions are flattened, so the
            last dimension is the reduction dimension.
        weight: Quantized weight; ``weight.shape[0]`` is the output width.
        block_size: ``[block_n, block_k]`` quantization block shape.
        weight_scale: Per-block scales for ``weight``.
        input_scale: Must be ``None``; only dynamic activation quantization
            is supported.
        bias: Optional bias added to the matmul result.

    Returns:
        Tensor of shape ``(*input.shape[:-1], weight.shape[0])`` with dtype
        ``input.dtype``.
    """
    assert input_scale is None

    # The int8 kernels operate on a 2-D (tokens x hidden) view.
    hidden_size = input.shape[-1]
    flat_input = input.view(-1, hidden_size)
    result_shape = [*input.shape[:-1], weight.shape[0]]

    act_q, act_scale = per_token_group_quant_int8(flat_input, block_size[1])
    result = w8a8_block_int8_matmul(
        act_q,
        weight,
        act_scale,
        weight_scale,
        block_size,
        output_dtype=input.dtype,
    )

    if bias is not None:
        result = result + bias
    return result.to(dtype=input.dtype).view(*result_shape)
def input_to_int8(
        x: torch.Tensor,
        dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to int8 with a single tensor-wise scale.

    One scale maps the largest-magnitude element of ``x`` onto
    ``iinfo(dtype).max``. Scaled values are rounded to the nearest integer
    and saturated to the representable range of ``dtype``.

    Args:
        x: Tensor to quantize.
        dtype: Integer output dtype; defaults to ``torch.int8``.

    Returns:
        ``(x_q, inv_scale)``: the contiguous quantized tensor and a float32
        scalar tensor such that ``x ~= x_q.float() * inv_scale``.
    """
    iinfo = torch.iinfo(dtype)
    int8_min, int8_max = iinfo.min, iinfo.max
    # Work in at least float32. In fp16 the 1e-12 floor below underflows to 0,
    # which gave all-zero inputs an infinite scale (NaN values, 0 inv-scale).
    # Scaling in bf16/fp16 also adds an extra rounding step.
    x_hp = x.to(torch.promote_types(x.dtype, torch.float32))
    min_val, max_val = x_hp.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = int8_max / amax
    # Round to nearest: a bare float->int cast truncates toward zero, biasing
    # every magnitude downward (float error in amax * (int8_max / amax) can
    # even put the max element on int8_max - 1).
    x_scl_sat = (x_hp * scale).round().clamp(min=int8_min, max=int8_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Dequantize a 2-D block-quantized tensor.

    Every ``block_n x block_k`` tile of ``x_q_block`` is multiplied by its own
    scale ``x_s[tile_row][tile_col]``. Tiles on the bottom and right edges may
    be smaller than a full block.

    Args:
        x_q_block: Block-wise quantized tensor of shape ``(n, k)``.
        x_s: Scales with shape ``(ceil(n / block_n), ceil(k / block_k))``.
        block_size: ``[block_n, block_k]``.

    Returns:
        A new float32 tensor of shape ``(n, k)``. The input is never modified.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # copy=True: without it, a float32 input is returned as-is by .to() and
    # the in-place scaling below would mutate the caller's tensor.
    x_dq_block = x_q_block.to(torch.float32, copy=True)
    for tile_row, row_start in enumerate(range(0, n, block_n)):
        for tile_col, col_start in enumerate(range(0, k, block_k)):
            # Slices past the end are clipped, which handles ragged edges.
            x_dq_block[row_start:row_start + block_n,
                       col_start:col_start + block_k] *= x_s[tile_row][tile_col]
    return x_dq_block
2025-05-02 23:41:10 -05:00
# `round_int8(x)` is a Triton device helper (used by `_per_token_quant_int8`)
# that applies a round() primitive to a float value and casts the result to
# tl.int8. It does not clamp; callers are expected to pass values already in
# int8 range. The round primitive is chosen per platform.
if current_platform.is_rocm():
    from triton.language import core

    # NOTE: This can be removed when hip.libdevice.round() is available.
    # Until then, the LLVM `llvm.round` intrinsic is exposed directly as an
    # elementwise extern with fp32 and fp64 overloads.
    @core.extern
    def round_f32(arg0, _builder=None):
        return core.extern_elementwise("",
                                       "", [arg0], {
                                           (core.dtype("fp32"), ):
                                           ("llvm.round", core.dtype("fp32")),
                                           (core.dtype("fp64"), ):
                                           ("llvm.round", core.dtype("fp64")),
                                       },
                                       is_pure=True,
                                       _builder=_builder)

    @triton.jit
    def round_int8(x):
        return round_f32(x).to(tl.int8)
else:

    @triton.jit
    def round_int8(x):
        # Non-ROCm path: CUDA libdevice round, then narrow to int8.
        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
2025-04-11 11:54:08 -06:00
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    """Quantize one row to int8 with a symmetric per-row scale.

    Each program handles row ``tl.program_id(0)``:
    - it loads up to ``BLOCK`` columns, masked by ``N``;
    - it writes ``round(x * 127 / absmax)`` as int8 to ``xq_ptr``;
    - it writes the row's dequantization scale ``absmax / 127`` to
      ``scale_ptr[row]``.

    ``BLOCK`` must be >= ``N``. There is no column loop, so columns at index
    ``BLOCK`` and beyond would be left unwritten.
    """
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282

    row_id = tl.program_id(0)

    cols = tl.arange(0, BLOCK)
    mask = cols < N

    # Masked-off lanes load 0.0, so they cannot raise the row max.
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
                other=0.0).to(tl.float32)
    # Floor the max at 1e-10 so an all-zero row does not divide by zero.
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    # Round before narrowing (platform-specific `round_int8` helper).
    x_q = round_int8(x_q)

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
    """Quantize ``x`` to int8 with one symmetric scale per row (token).

    All leading dimensions are treated as rows and the last dimension as the
    row length. ``x`` must be contiguous and at least 2-D, because the kernel
    is given ``x.stride(-2)`` as its row stride.

    Returns:
        ``(x_q, scales)``: ``x_q`` is int8 with the shape of ``x``, and
        ``scales`` is float32 with shape ``(*x.shape[:-1], 1)``.
    """
    hidden = x.shape[-1]
    num_rows = x.numel() // hidden

    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1, ),
                         device=x.device,
                         dtype=torch.float32)

    # One program covers a whole row, so the block must span every column.
    block = triton.next_power_of_2(hidden)
    # Heuristic: roughly one warp per 256 columns, clamped to [1, 8].
    warps = min(8, max(1, block // 256))

    assert x.is_contiguous()
    _per_token_quant_int8[(num_rows, )](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=hidden,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )
    return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Lower bound on each group's abs-max, to avoid dividing by zero
    eps,
    # Information for int8
    int8_min,
    int8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.

    This function converts the tensor values into int8 values.

    Each program quantizes one group of ``N`` elements:
    - it computes ``y_s = max(absmax, eps) / int8_max``;
    - it stores ``clamp(y / y_s, int8_min, int8_max)`` cast to the output
      element type;
    - it stores one float scale per group at ``y_s_ptr[group]``.

    ``BLOCK`` must be >= ``N`` (there is no column loop).
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    # The same stride is used for input and output, so `y` and `y_q` are
    # assumed to share a layout.
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / int8_max
    # NOTE(review): this is a direct cast after clamping, with no
    # round-to-nearest step (contrast `_per_token_quant_int8`, which uses
    # `round_int8`), so values are presumably truncated toward zero. Confirm
    # this matches the CUDA `per_token_group_quant_int8` op before changing.
    y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    On CUDA the compiled `_C.per_token_group_quant_int8` op is used;
    otherwise a Triton kernel launches one program per group.

    Args:
        x: The input tensor with ndim >= 2. Must be contiguous, and its last
            dimension must be divisible by `group_size`.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
        dtype: The dtype of output tensor. Note that only `torch.int8`
            is supported for now.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization. The scales are float32 with shape
        ``(*x.shape[:-1], x.shape[-1] // group_size)``.

    Raises:
        AssertionError: if the last dimension is not divisible by
            `group_size` or `x` is not contiguous.
    """
    # The message previously said "cannot be divisible", the inverse of
    # the condition actually being checked.
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_max = iinfo.max
    int8_min = iinfo.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size, ),
        device=x.device,
        dtype=torch.float32,
    )
    # prefer CUDA kernel if available
    if current_platform.is_cuda():
        torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
                                                float(int8_min),
                                                float(int8_max))
        return x_q, x_s

    # Triton fallback: since `x` is contiguous, it is viewed as M rows of
    # `group_size` elements and each row (group) gets one program.
    M = x.numel() // group_size
    N = group_size

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_int8[(M, )](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        int8_min=int8_min,
        int8_max=int8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.

    Each program computes one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of ``C``.
    For every K tile it accumulates ``dot(A_tile, B_tile)`` scaled by:
    - ``As[m, kg]``, a per-row scale for A;
    - ``Bs[kg, n // group_n]``, a per-column-group scale for B;
    where ``kg = k_start // group_k``. The (k, n) view of ``B``/``Bs`` comes
    from the strides supplied by the caller.

    NOTE: the scales for a K tile are read once at ``k_start // group_k``,
    so each ``BLOCK_SIZE_K`` tile must lie inside a single quantization group
    (``group_k`` should be a multiple of ``BLOCK_SIZE_K``).
    """

    # Grouped program ordering (as in the Triton matmul tutorial): program
    # ids walk GROUP_SIZE_M tile-rows at a time before advancing along N.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Out-of-range rows/cols wrap via `% M` / `% N` so loads stay in bounds;
    # those lanes are discarded by `c_mask` at the final store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A-scale per row, one B-scale per group_n columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Only the K dimension is masked; the tail tile is zero-padded.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)

        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
                                                         None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Narrow the float32 accumulator to C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
                                block_k: int) -> Optional[dict[int, Any]]:
    """Load tuned launch configurations for the W8A8 block INT8 kernel.

    The lookup file lives in the ``configs`` directory next to this module,
    named ``N={N},K={K},device_name=<device>,dtype=int8_w8a8,
    block_shape=[{block_n}, {block_k}].json``. Here ``<device>`` is the
    current device name with spaces replaced by underscores.

    Returns:
        A dict mapping batch size (M) to a kernel config dict. To evaluate the
        kernel at batch size bs, pick the entry whose key is closest to bs.
        Returns ``None`` (after logging a warning) when no tuned file exists.
        Results are memoised per argument tuple by ``functools.lru_cache``.
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json"  # noqa: E501
    configs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               "configs")
    config_file_path = os.path.join(configs_dir, json_file_name)

    if not os.path.exists(config_file_path):
        # Fall back to the caller's default configuration.
        logger.warning(
            ("Using default W8A8 Block INT8 kernel config. Performance might "
             "be sub-optimal! Config file not found at %s"),
            config_file_path,
        )
        return None

    with open(config_file_path) as f:
        logger.info(
            "Using configuration from %s for W8A8 Block INT8 kernel.",
            config_file_path,
        )
        raw_configs = json.load(f)
    # JSON object keys are strings; the batch sizes are used as ints.
    return {int(batch): cfg for batch, cfg in raw_configs.items()}
def w8a8_block_int8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Matrix multiplication with block-wise int8 quantization.

    Computes ``C = A @ B.T`` (each operand dequantized by its scales) with the
    Triton block int8 kernel. A tuned launch config is used when one exists
    for ``(N, K, block_size)``; otherwise a default config is used.

    Args:
        A: The input tensor, e.g., activation; shape ``(..., K)``, contiguous.
        B: The input tensor, e.g., weight; shape ``(N, K)``, contiguous.
        As: The per-token-group quantization scale for `A`, with shape
            ``(..., ceil(K / block_k))``.
        Bs: The per-block quantization scale for `B`, with shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: The block size for per-block quantization. It should be
            2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, shape ``(..., N)``.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C = A.new_empty(A.shape[:-1] + (N, ), dtype=output_dtype)

    tuned = get_w8a8_block_int8_configs(N, K, block_n, block_k)
    if tuned:
        # Use the tuned entry whose batch size is closest to M.
        nearest_m = min(tuned, key=lambda m: abs(m - M))
        config = tuned[nearest_m]
    else:
        # Default config. With block-wise quant, BLOCK_SIZE_K is set to the
        # quantization block so each K tile stays within one scale group.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

    def launch_grid(meta):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        tiles_m = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, )

    # B is (N, K); passing its strides swapped lets the kernel read it as
    # a (K, N) operand. The same applies to Bs.
    _w8a8_block_int8_matmul[launch_grid](
        A, B, C, As, Bs,
        M, N, K,
        block_n, block_k,
        A.stride(-2), A.stride(-1),
        B.stride(1), B.stride(0),
        C.stride(-2), C.stride(-1),
        As.stride(-2), As.stride(-1),
        Bs.stride(1), Bs.stride(0),
        **config,
    )
    return C