[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <utility>
|
||||
|
||||
#include "../../cuda_vec_utils.cuh"
|
||||
|
||||
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
|
||||
return round_up(m, ROW_TILE);
|
||||
}
|
||||
|
||||
// Compute the shape of the swizzled SF output tensor.
|
||||
// Returns (rounded_m, rounded_n / 4) where:
|
||||
// rounded_m = round_up(m, 128)
|
||||
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
|
||||
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
|
||||
int64_t n) {
|
||||
int64_t rounded_m = round_up(m, static_cast<int64_t>(128));
|
||||
int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE;
|
||||
int64_t rounded_n = round_up(scale_n, static_cast<int64_t>(4));
|
||||
return {rounded_m, rounded_n / 4};
|
||||
}
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
|
||||
uint32_t val;
|
||||
|
||||
Reference in New Issue
Block a user