[BugFix] Fix fp4 quant kernel on CUDA 12.8 (#35210)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
Roberto L. Castro
2026-02-26 03:32:50 +01:00
committed by GitHub
parent 160424a937
commit 86c3b5a808
2 changed files with 11 additions and 7 deletions

View File

@@ -109,7 +109,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
int32_t sf_n_unpadded, Type const* __restrict__ in,
int32_t sf_n_unpadded, int32_t num_packed_cols,
Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
@@ -131,7 +132,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < sf_n_unpadded) {
if (colIdx < num_packed_cols) {
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
@@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
int grid_x = std::min(
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
@@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr,
input_sf_ptr,
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});