Fix CUDA kernel index data type in vllm/csrc/quantization/fused_kernels/layernorm_utils.cuh +10 (#15159)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
This commit is contained in:
@@ -55,11 +55,11 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K +
|
||||
blockIdx.z * params.SplitK * 4;
|
||||
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const auto lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// For matrix A, a block load/store Mtile(row) x 32(col) elements in
|
||||
// multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
|
||||
const int Aldg_row_base_idx = threadIdx.x / 4;
|
||||
const auto Aldg_row_base_idx = threadIdx.x / 4;
|
||||
Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A;
|
||||
const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx;
|
||||
|
||||
@@ -67,7 +67,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
// elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
|
||||
// * 128(col) per iter
|
||||
Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B;
|
||||
const int Bldg_row_base_idx = threadIdx.x / 8;
|
||||
const auto Bldg_row_base_idx = threadIdx.x / 8;
|
||||
const int Bldg_base_offset =
|
||||
Bldg_row_base_idx * params.K * 4 + Bldg_col_idx;
|
||||
|
||||
@@ -89,7 +89,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
B_ldg_guard = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) {
|
||||
int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
|
||||
auto m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
|
||||
if (m_idx < params.M) {
|
||||
A_ldg_guard |= (1u << i);
|
||||
}
|
||||
@@ -98,8 +98,8 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
const int N_padded = (params.N + 31) / 32 * 32;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) {
|
||||
int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
|
||||
i * N_SIZE_ONE_LOAD;
|
||||
auto n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
|
||||
i * N_SIZE_ONE_LOAD;
|
||||
if (n_idx < N_padded) {
|
||||
B_ldg_guard |= (1u << i);
|
||||
}
|
||||
@@ -355,7 +355,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
__device__ void fused_splitk_reduce() {
|
||||
// need splitk-reduce if enable splitk
|
||||
if (gridDim.z > 1) {
|
||||
int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
|
||||
auto blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
|
||||
// Wait for all previous blocks in the splitk direction to accumulate the
|
||||
// results into C_tmp
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -371,7 +371,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
|
||||
auto C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
|
||||
if (blockIdx.z != 0) {
|
||||
// expecting that temporary register here reuses the previous A&B frag
|
||||
// register
|
||||
@@ -456,7 +456,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
|
||||
FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset;
|
||||
// C_tile lds and stg
|
||||
int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
|
||||
auto m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
|
||||
bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N;
|
||||
if (WARP_NTILE == 32) {
|
||||
int lds_c_base_offset = warp_id * Mtile * WARP_NTILE +
|
||||
@@ -580,9 +580,9 @@ __global__ void __launch_bounds__(BLOCK)
|
||||
int sts_stage_idx = 0;
|
||||
int lds_stage_idx = 0;
|
||||
|
||||
int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
|
||||
? params.SplitK
|
||||
: params.K - blockIdx.z * params.SplitK;
|
||||
auto tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
|
||||
? params.SplitK
|
||||
: params.K - blockIdx.z * params.SplitK;
|
||||
int k_tiles = (tb_k_slice + 31) / 32;
|
||||
int first_k_tile = tb_k_slice - (k_tiles - 1) * 32;
|
||||
|
||||
@@ -777,13 +777,13 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
|
||||
const QT* qdata, const FT* scales, const FT* zeros, FT* fdata,
|
||||
const int N_32align, const int N, const int K) {
|
||||
__shared__ FT smem[64 * 32];
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int lane_id = threadIdx.x % 32;
|
||||
const int src_row_idx = blockIdx.x * 8 + lane_id / 4;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto lane_id = threadIdx.x % 32;
|
||||
const auto src_row_idx = blockIdx.x * 8 + lane_id / 4;
|
||||
const int src_col_idx =
|
||||
blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16;
|
||||
const int src_offset = src_row_idx * K * 4 + src_col_idx;
|
||||
int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;
|
||||
auto params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;
|
||||
|
||||
QT qval_reg[16];
|
||||
const QT* pdata = qdata + src_offset;
|
||||
@@ -829,8 +829,8 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
|
||||
*reinterpret_cast<uint4*>(smem + lds_base_offset + i * 32 * 32);
|
||||
}
|
||||
|
||||
const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
|
||||
const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
|
||||
const auto dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
|
||||
const auto dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
int dst_row_kidx = dst_row_base_kidx + i * 32;
|
||||
@@ -1008,4 +1008,4 @@ torch::Tensor allspark_w8a16_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user