[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
   return b;
 }
 
+// Compute SF output offset for swizzled tensor core layout.
+// SF layout: [numMTiles, numKTiles, 32, 4, 4]
+// Caller must precompute: numKTiles = (numCols + 63) / 64
 template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
-__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
-                                                       int numCols,
-                                                       SFType* SFout) {
+__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
+    int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
   static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                 CVT_FP4_NUM_THREADS_PER_SF == 2);
 
   // One pair of threads write one SF to global memory.
   // TODO: stage through smem for packed STG.32
   // is it better than STG.8 from 4 threads ?
-  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
-    // SF vector index (16 elements share one SF in the K dimension).
-    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
-    int32_t mIdx = rowIdx;
-
-    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-    int32_t mTileIdx = mIdx / (32 * 4);
-    // SF vector size 16.
-    int factor = CVT_FP4_SF_VEC_SIZE * 4;
-    int32_t numKTiles = (numCols + factor - 1) / factor;
-    int64_t mTileStride = numKTiles * 32 * 4 * 4;
-
-    int32_t kTileIdx = (kIdx / 4);
-    int64_t kTileStride = 32 * 4 * 4;
-
-    // M tile layout [32, 4] is column-major.
-    int32_t outerMIdx = (mIdx % 32);
-    int64_t outerMStride = 4 * 4;
-
-    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-    int64_t innerMStride = 4;
-
-    int32_t innerKIdx = (kIdx % 4);
-    int64_t innerKStride = 1;
-
-    // Compute the global offset.
-    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
-                       outerMIdx * outerMStride + innerMIdx * innerMStride +
-                       innerKIdx * innerKStride;
-
-    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
+  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
+    return nullptr;
   }
-  return nullptr;
+
+  // SF vector index (16 elements share one SF in the K dimension).
+  int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
+  int32_t mIdx = rowIdx;
+
+  // Decompose indices using bitwise ops (all divisors are powers of 2).
+  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+  int32_t mTileIdx = mIdx >> 7;         // mIdx / 128
+  int32_t outerMIdx = mIdx & 31;        // mIdx % 32
+  int32_t innerMIdx = (mIdx >> 5) & 3;  // (mIdx / 32) % 4
+  int32_t kTileIdx = kIdx >> 2;         // kIdx / 4
+  int32_t innerKIdx = kIdx & 3;         // kIdx % 4
+
+  // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
+  // outerMIdx * 16 + innerMIdx * 4 + innerKIdx
+  // Use bitwise OR for non-overlapping lower bits.
+  int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
+                         << 9 |
+                     (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
+
+  return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
 }
 
 // Quantizes the provided PackedVec into the uint32_t output
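
Note: the rewrite works because every divisor in the old index math is a power of two, so the div/mod chains become shifts and masks, and because the decomposed fields occupy disjoint bit ranges below bit 9, so they can be OR-combined instead of added. Below is a minimal host-side sketch (not part of this commit; the helper names and test shape are illustrative assumptions) that checks the two formulations agree:

// check_sf_offset.cpp -- illustrative only, not from the commit.
// Verifies the bitwise SF offset against the original div/mod arithmetic
// for the swizzled layout [numMTiles, numKTiles, 32, 4, 4].
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int CVT_FP4_SF_VEC_SIZE = 16;  // 16 elements share one SF

// Old formulation: derives numKTiles from numCols and multiplies out strides.
int64_t sf_offset_reference(int32_t mIdx, int32_t kIdx, int numCols) {
  int factor = CVT_FP4_SF_VEC_SIZE * 4;  // 64
  int32_t numKTiles = (numCols + factor - 1) / factor;
  int32_t mTileIdx = mIdx / (32 * 4);
  int64_t mTileStride = numKTiles * 32 * 4 * 4;  // numKTiles * 512
  int32_t kTileIdx = kIdx / 4;
  int64_t kTileStride = 32 * 4 * 4;  // 512
  int32_t outerMIdx = mIdx % 32;
  int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
  int32_t innerKIdx = kIdx % 4;
  return mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * 16 +
         innerMIdx * 4 + innerKIdx;
}

// New formulation: numKTiles is precomputed by the caller; shifts and masks
// replace div/mod, and the non-overlapping bit fields are combined with OR.
int64_t sf_offset_bitwise(int32_t mIdx, int32_t kIdx, int32_t numKTiles) {
  int32_t mTileIdx = mIdx >> 7;         // mIdx / 128
  int32_t outerMIdx = mIdx & 31;        // mIdx % 32
  int32_t innerMIdx = (mIdx >> 5) & 3;  // (mIdx / 32) % 4
  int32_t kTileIdx = kIdx >> 2;         // kIdx / 4
  int32_t innerKIdx = kIdx & 3;         // kIdx % 4
  return (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx) << 9 |
         (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
}

int main() {
  const int numCols = 4096;                       // assumed test shape
  const int32_t numKTiles = (numCols + 63) / 64;  // hoisted, as in the diff
  for (int32_t m = 0; m < 512; ++m)
    for (int32_t k = 0; k < numCols / CVT_FP4_SF_VEC_SIZE; ++k)
      assert(sf_offset_reference(m, k, numCols) ==
             sf_offset_bitwise(m, k, numKTiles));
  printf("bitwise SF offset matches div/mod reference\n");
}

The OR is safe because the tile term has zero low bits after the << 9, while outerMIdx <= 31 fills bits 4..8, innerMIdx <= 3 fills bits 2..3, and innerKIdx <= 3 fills bits 0..1.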
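
Usage note: since the helper no longer derives numKTiles from numCols, each call site has to precompute int32_t numKTiles = (numCols + 63) / 64 once (64 being CVT_FP4_SF_VEC_SIZE * 4), outside the per-element path. The integer division then happens once per thread rather than once per SF write, which, together with __forceinline__ and the early return for non-leader threads, is presumably the per-token overhead the small-batch tuning in the title targets.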