[Refactor][Kernel] Add global helper to deduplicate vectorized memory ops (#35105)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
This commit is contained in:
Roberto L. Castro
2026-02-28 01:28:17 +01:00
committed by GitHub
parent e3691988d0
commit a201ad72d8
6 changed files with 474 additions and 372 deletions

View File

@@ -39,12 +39,12 @@ namespace vllm {
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
int32_t num_padded_cols,
int32_t num_packed_cols,
Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>;
using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
@@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < num_padded_cols) {
if (colIdx < num_packed_cols) {
PackedVec in_vec;
PackedVec in_vec2;
int64_t inOffset =
@@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
ld256_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
valid);
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec2),
&reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
valid);
} else {
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
ld128_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
valid);
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec2),
&reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
valid);
}
// Compute silu and mul
@@ -142,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD);
int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD);
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
int grid_x = std::min(
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
@@ -154,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, input_ptr, input_sf_ptr,
m, n, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});