[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)

Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
Terry Gao
2026-03-16 15:51:46 -07:00
committed by GitHub
parent 7961486a9b
commit 3e6a1e1686
12 changed files with 213 additions and 44 deletions

View File

@@ -18,6 +18,7 @@
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh"
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
return round_up(m, ROW_TILE);
}
// Compute the shape of the swizzled SF output tensor.
// Returns (rounded_m, rounded_n / 4) where:
// rounded_m = round_up(m, 128)
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
int64_t n) {
int64_t rounded_m = round_up(m, static_cast<int64_t>(128));
int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE;
int64_t rounded_n = round_up(scale_n, static_cast<int64_t>(4));
return {rounded_m, rounded_n / 4};
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val;