// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <ATen/cuda/CUDAContext.h>
#include <cuda/ptx>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <torch/all.h>

#include "cute/tensor.hpp"

namespace expert_specialization {

using namespace cute;

constexpr uint32_t THREAD_BLOCK_SIZE = 128;
constexpr uint32_t WARP_SIZE = 32;
constexpr int BLOCK_M = 128;
constexpr int BLOCK_K = 128;

using ThrLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
using ValLayout = Layout<Shape<_1, _16>>;
using SfR2SThrLayout = Layout<Shape<_16, _4>, Stride<_4, _1>>;
using SfR2SValLayout = Layout<Shape<_1, _1>>;
using ScaleFactorTileLayout =
    Layout<Shape<Shape<_32, _4>, _4>, Stride<Stride<_16, _4>, _1>>;

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}

// Some code references TRT-LLM:
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh
template <typename FragmentS, typename FragmentD>
__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s,
                                                     FragmentD& fragment_d) {
  using FragmentSLayout = typename FragmentS::layout_type;
  using FragmentDLayout = typename FragmentD::layout_type;
  FragmentSLayout fragment_s_layout;
  FragmentDLayout fragment_d_layout;
  static_assert(is_static<FragmentSLayout>::value &&
                size(fragment_s_layout) == 16);
  static_assert(is_static<FragmentDLayout>::value &&
                size(fragment_d_layout) == 16);

  constexpr int eles_per_thr = 16;
  using ValType = typename FragmentS::element_type;
  using VecType =
      std::conditional_t<std::is_same_v<ValType, cutlass::bfloat16_t>,
                         __nv_bfloat162, __half2>;
  VecType vec[8];
  // Assign vals
  vec[0].x = fragment_s(Int<0>{});
  vec[0].y = fragment_s(Int<1>{});
  vec[1].x = fragment_s(Int<2>{});
  vec[1].y = fragment_s(Int<3>{});
  vec[2].x = fragment_s(Int<4>{});
  vec[2].y = fragment_s(Int<5>{});
  vec[3].x = fragment_s(Int<6>{});
  vec[3].y = fragment_s(Int<7>{});
  vec[4].x = fragment_s(Int<8>{});
  vec[4].y = fragment_s(Int<9>{});
  vec[5].x = fragment_s(Int<10>{});
  vec[5].y = fragment_s(Int<11>{});
  vec[6].x = fragment_s(Int<12>{});
  vec[6].y = fragment_s(Int<13>{});
  vec[7].x = fragment_s(Int<14>{});
  vec[7].y = fragment_s(Int<15>{});

  auto local_max = __habs2(vec[0]);
  for (int i = 1; i < eles_per_thr / 2; i++) {
    local_max = __hmax2(__habs2(vec[i]), local_max);
  }
  local_max = __hmax2(__shfl_xor_sync(uint32_t(-1), local_max, 1), local_max);

  // Get the final absolute maximum values.
  float block_max(0.0f);
  if constexpr (std::is_same_v<VecType, __nv_bfloat162>) {
    block_max = __bfloat162float(__hmax(local_max.x, local_max.y));
  } else {
    block_max = __half2float(__hmax(local_max.x, local_max.y));
  }

  // Get the SF (max value of the vector / max value of mxfp8).
  float sf_val = block_max * reciprocal_approximate_ftz(448.0f);
  // 8 bits representation of the SF.
  uint8_t fp8_sf_val;
  __nv_fp8_e8m0 tmp_sf_val;
  tmp_sf_val.__x =
      __nv_cvt_float_to_e8m0(sf_val, __NV_SATFINITE, cudaRoundPosInf);
  sf_val = static_cast<float>(tmp_sf_val);
  fp8_sf_val = tmp_sf_val.__x;
  // Get the output scale (reciprocal of the SFValue).
  float output_scale =
      block_max != 0.f ? reciprocal_approximate_ftz(sf_val) : 0.0f;

  // Convert the input to float.
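  // Two adjacent threads (paired by the XOR-1 shuffle above) jointly cover one
  // 32-element MXFP8 block: 448.0f is the largest finite E4M3 magnitude, the
  // per-block scale is rounded up to a power-of-two E8M0 byte, and
  // output_scale is the reciprocal of the decoded scale (or 0 for an all-zero
  // block) applied to every element before packing to E4M3 below.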
  float2 fp2_vals[eles_per_thr / 2];
  #pragma unroll
  for (int i = 0; i < eles_per_thr / 2; i++) {
    if constexpr (std::is_same_v<VecType, __half2>) {
      fp2_vals[i] = __half22float2(vec[i]);
    } else {
      fp2_vals[i] = __bfloat1622float2(vec[i]);
    }
    fp2_vals[i].x *= output_scale;
    fp2_vals[i].y *= output_scale;
  }

  union {
    uint8_t bytes[16];
    __nv_fp8x2_e4m3 elts[8];
  } u;
  u.elts[0] = __nv_fp8x2_e4m3(fp2_vals[0]);
  u.elts[1] = __nv_fp8x2_e4m3(fp2_vals[1]);
  u.elts[2] = __nv_fp8x2_e4m3(fp2_vals[2]);
  u.elts[3] = __nv_fp8x2_e4m3(fp2_vals[3]);
  u.elts[4] = __nv_fp8x2_e4m3(fp2_vals[4]);
  u.elts[5] = __nv_fp8x2_e4m3(fp2_vals[5]);
  u.elts[6] = __nv_fp8x2_e4m3(fp2_vals[6]);
  u.elts[7] = __nv_fp8x2_e4m3(fp2_vals[7]);

  fragment_d(Int<0>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[0]);
  fragment_d(Int<1>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[1]);
  fragment_d(Int<2>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[2]);
  fragment_d(Int<3>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[3]);
  fragment_d(Int<4>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[4]);
  fragment_d(Int<5>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[5]);
  fragment_d(Int<6>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[6]);
  fragment_d(Int<7>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[7]);
  fragment_d(Int<8>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[8]);
  fragment_d(Int<9>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[9]);
  fragment_d(Int<10>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[10]);
  fragment_d(Int<11>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[11]);
  fragment_d(Int<12>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[12]);
  fragment_d(Int<13>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[13]);
  fragment_d(Int<14>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[14]);
  fragment_d(Int<15>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[15]);
  return fp8_sf_val;
}

template <typename TensorS, typename TensorP, typename TensorD,
          typename TensorSharedSF, typename TensorSF, typename TiledCopyG2R,
          typename TiledCopyR2G, typename TiledCopyR2S>
__inline__ __device__ void mxfp8_experts_quant_tile(
    TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d,
    TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m,
    TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g,
    TiledCopyR2S& tiled_copy_r2s) {
  static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 &&
                size(get<1>(typename TensorS::layout_type{})) == 128 &&
                stride(get<1>(typename TensorS::layout_type{})) == 1);
  static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 &&
                size(get<1>(typename TensorD::layout_type{})) == 128 &&
                stride(get<1>(typename TensorD::layout_type{})) == 1);
  static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 &&
                size(get<1>(typename TensorP::layout_type{})) == 128);
  static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 &&
                size(get<1>(typename TensorSharedSF::layout_type{})) == 4);
  static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 &&
                size(get<1>(typename TensorSF::layout_type{})) == 4);

  using Tiler_MN = typename TiledCopyG2R::Tiler_MN;
  auto tiler_mn = Tiler_MN{};
  static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128);

  auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn);
  auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn);
  auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn);
  static_assert(size<2>(tiled_tensor_s) == 1);
  static_assert(size<2>(tiled_tensor_p) == 1);
  static_assert(size<2>(tiled_tensor_d) == 1);
  auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s);
  auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p);
  auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d);

  using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN;
  auto sf_tiler_mn = SF_Tiler_MN{};
  static_assert(size<0>(sf_tiler_mn) == 16
                && size<1>(sf_tiler_mn) == 4);
  auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn);
  auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn);
  auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf);
  auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf);

  constexpr int tile_loop_count = size<1>(tiled_tensor_s);
  constexpr int rows_in_tile = 16;

  // We don't need to clear shared memory
  // clear(squeeze_tiled_tensor_shared_sf);

  #pragma unroll 4
  for (int t = 0; t < tile_loop_count; t++) {
    if (t * rows_in_tile >= m) {
      break;
    }
    auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t));
    auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t));
    auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t));
    auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t));
    auto current_copy_tile_shared_sf =
        tensor<0>(squeeze_tiled_tensor_shared_sf(_, t));

    // Global to Register copy
    auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x);
    auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s);
    auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p);
    auto input_fragment = make_fragment_like(thr_tile_g2r_s);

    // Register to Global copy
    auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x);
    auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d);
    auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p);
    auto output_fragment = make_fragment_like(thr_tile_r2g_d);

    // Register to Shared copy
    auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2);
    auto thr_tile_r2s_shared_sf =
        thr_copy_r2s.partition_D(current_copy_tile_shared_sf);
    auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf);

    // CopyG2R & convert & CopyR2G
    copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment);
    uint8_t fp8_sf_val =
        cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment);
    copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d);
    shared_sf_fragment[0] = fp8_sf_val;

    // Before first copy r2s, clear shared memory and wait previous group
    if (t == 0 && threadIdx.x == 0) {
      // Wait for the group to have completed reading from shared memory.
      cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>());
    }
    __syncthreads();
    if (threadIdx.x % 2 == 0) {
      copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf);
    }
    __syncthreads();
  }

  // Wait for shared memory writes to be visible to TMA engine.
  cuda::ptx::fence_proxy_async(cuda::ptx::space_shared);  // b)
  __syncthreads();
  if (threadIdx.x == 0) {
    cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared,
                             squeeze_tiled_tensor_sf.data().get(),
                             squeeze_tiled_tensor_shared_sf.data().get(), 512);
    // Wait for TMA transfer to have finished reading shared memory.
    // Create a "bulk async-group" out of the previous bulk copy operation.
    cuda::ptx::cp_async_bulk_commit_group();
  }
  __syncthreads();
}

template <typename T_IN, typename TiledCopyG2R, typename TiledCopyR2G,
          typename TiledCopyR2S>
__global__ void mxfp8_experts_quant_kernel(
    const T_IN* input, const int* problem_sizes, const int* expert_offsets,
    const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output,
    uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r,
    TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
  __shared__ __align__(512) uint8_t shared_memory[512];
  ScaleFactorTileLayout scale_factor_tile_layout{};
  auto scale_factor_shared =
      make_tensor(make_smem_ptr(shared_memory),
                  scale_factor_tile_layout);  // ((_32,_4), _4):((_16,_4), _1)

  // TODO: Transform Groupwise Schedule into a more efficient Schedule
  for (int g = 0; g < groups; g++) {
    int m = problem_sizes[g * 3 + 0];
    int k = problem_sizes[g * 3 + 2];
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[g]);
    int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[g]);

    auto input_tensor = make_tensor(
        make_gmem_ptr(input + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) half_t/bfloat16_t
    auto quant_output_tensor = make_tensor(
        make_gmem_ptr(quant_output + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) cutlass::float_e4m3_t

    auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32);
    auto scale_factor_layout = tile_to_shape(
        scale_factor_tile_layout, scale_factor_shape, LayoutRight{});
    // layout<0>(layout<0>(scale_factor_layout)): (_32,_4):(_16,_4) -- static
    // layout<1>(layout<0>(scale_factor_layout)): M_align_128 / 128 --
    //   dynamic shape, dynamic stride
    // layout<0>(layout<1>(scale_factor_layout)): _4:_1 -- static
    // layout<1>(layout<1>(scale_factor_layout)): (K / 32) / 4 : _512 --
    //   dynamic shape, static stride

    // Reshape to zipped layout for 1D indexing
    auto zipped_scale_factor_layout = make_layout(
        make_layout(layout<0>(layout<0>(scale_factor_layout)),
                    layout<0>(layout<1>(scale_factor_layout))),
        make_layout(layout<1>(layout<0>(scale_factor_layout)),
                    layout<1>(layout<1>(scale_factor_layout))));
    // (((_32,_4),_4),(M_align_128 / 128,(K / 32) / 4)):
    //   (((_16,_4),_1),(?,_512))
    auto scale_factor_tensor =
        make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)),
                    zipped_scale_factor_layout);

    // Used for cases where M is not divisible by 128 (most scenarios).
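    // The identity tensor below carries each element's (m, k) coordinate;
    // mapping it through elem_less() against the real (M, K) extent produces
    // a boolean predicate tensor, and copy_if() later uses it to skip
    // out-of-bounds elements in partial tiles.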
    auto input_shape = shape(input_tensor);  // (M, K)
    auto identity_tensor = make_identity_tensor(input_shape);
    auto predict_tensor = cute::lazy::transform(
        identity_tensor, [&](auto c) { return elem_less(c, input_shape); });

    // (_128, _128)
    auto tiler = make_shape(Int<BLOCK_M>{}, Int<BLOCK_K>{});
    auto tiled_input_tensor = zipped_divide(
        input_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_quant_output_tensor = zipped_divide(
        quant_output_tensor,
        tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_predict_tensor = zipped_divide(
        predict_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))

    auto total_tiles =
        size<1>(tiled_input_tensor);  // cdiv(M, 128) * cdiv(K, 128)
    decltype(total_tiles) blk_offset = blockIdx.x;
    while (blk_offset < total_tiles) {
      auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset));
      auto current_quant_output_tile =
          tensor<0>(tiled_quant_output_tensor(_, blk_offset));
      auto current_predict_tile =
          tensor<0>(tiled_predict_tensor(_, blk_offset));
      auto current_scale_factor_tile =
          tensor<0>(scale_factor_tensor(_, blk_offset));
      mxfp8_experts_quant_tile<
          decltype(current_input_tile), decltype(current_predict_tile),
          decltype(current_quant_output_tile), decltype(scale_factor_shared),
          decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G,
          TiledCopyR2S>(current_input_tile, current_predict_tile,
                        current_quant_output_tile, scale_factor_shared,
                        current_scale_factor_tile, m, tiled_copy_g2r,
                        tiled_copy_r2g, tiled_copy_r2s);
      blk_offset += gridDim.x;
    }
  }
#endif
}

template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
                                const torch::Tensor& problem_sizes,
                                const torch::Tensor& expert_offsets,
                                const torch::Tensor& blockscale_offsets,
                                torch::Tensor& quant_output,
                                torch::Tensor& scale_factor) {
  ThrLayout thr_layout{};
  ValLayout val_layout{};
  SfR2SThrLayout r2s_thr_layout{};
  SfR2SValLayout r2s_val_layout{};

  using CopyOpG2R = UniversalCopy<cutlass::AlignedArray<T_IN, 16>>;
  using CopyAtomG2R = cute::Copy_Atom<CopyOpG2R, T_IN>;
  auto tiled_copy_g2r = cute::make_tiled_copy(
      CopyAtomG2R{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  using CopyOpR2G =
      UniversalCopy<cutlass::AlignedArray<cutlass::float_e4m3_t, 16>>;
  using CopyAtomR2G = cute::Copy_Atom<CopyOpR2G, cutlass::float_e4m3_t>;
  auto tiled_copy_r2g = cute::make_tiled_copy(
      CopyAtomR2G{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  using CopyOpR2S = UniversalCopy<uint8_t>;
  using CopyAtomR2S = cute::Copy_Atom<CopyOpR2S, uint8_t>;
  auto tiled_copy_r2s = cute::make_tiled_copy(
      CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout);  // Tiler_MN: (16, 4)

  int max_active_blocks_per_sm = -1;
  AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks_per_sm,
      mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
                                 decltype(tiled_copy_r2g),
                                 decltype(tiled_copy_r2s)>,
      THREAD_BLOCK_SIZE, 0));
  dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
                max_active_blocks_per_sm,
            1, 1);
  dim3 block(THREAD_BLOCK_SIZE, 1, 1);

  int num_experts = (int)problem_sizes.size(0);
  auto stream = at::cuda::getCurrentCUDAStream();
  mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
                             decltype(tiled_copy_r2g),
                             decltype(tiled_copy_r2s)>
      <<<grid, block, 0, stream>>>(
          reinterpret_cast<const T_IN*>(input.data_ptr()),
          reinterpret_cast<const int*>(problem_sizes.data_ptr()),
          reinterpret_cast<const int*>(expert_offsets.data_ptr()),
          reinterpret_cast<const int*>(blockscale_offsets.data_ptr()),
          reinterpret_cast<cutlass::float_e4m3_t*>(quant_output.data_ptr()),
          reinterpret_cast<uint8_t*>(scale_factor.data_ptr()), num_experts,
          tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s);
}

}  // namespace expert_specialization
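// Illustrative only: this header is expected to be driven by a host-side
// dispatch on the input dtype. The wrapper name below is hypothetical, and
// T_IN = cutlass::half_t / cutlass::bfloat16_t is an assumption about the
// intended instantiations; this is a sketch, not part of this header's API.
//
//   void mxfp8_blockscaled_group_quant(
//       const torch::Tensor& input, const torch::Tensor& problem_sizes,
//       const torch::Tensor& expert_offsets,
//       const torch::Tensor& blockscale_offsets, torch::Tensor& quant_output,
//       torch::Tensor& scale_factor) {
//     if (input.scalar_type() == at::ScalarType::Half) {
//       expert_specialization::launch_mxfp8_experts_quant<cutlass::half_t>(
//           input, problem_sizes, expert_offsets, blockscale_offsets,
//           quant_output, scale_factor);
//     } else if (input.scalar_type() == at::ScalarType::BFloat16) {
//       expert_specialization::launch_mxfp8_experts_quant<cutlass::bfloat16_t>(
//           input, problem_sizes, expert_offsets, blockscale_offsets,
//           quant_output, scale_factor);
//     } else {
//       TORCH_CHECK(false, "mxfp8 group quant expects fp16 or bf16 input");
//     }
//   }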