[Bugfix] Fix NaN/Inf output in Marlin when dtype=float16 (#33972)
Signed-off-by: IriKa Qiu <qiujie.jq@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
const int4 *__restrict__ b_bias_ptr, \
|
const int4 *__restrict__ b_bias_ptr, \
|
||||||
const float *__restrict__ a_scales_ptr, \
|
const float *__restrict__ a_scales_ptr, \
|
||||||
const int4 *__restrict__ scales_ptr, \
|
const int4 *__restrict__ scales_ptr, \
|
||||||
const uint16_t *__restrict__ global_scale_ptr, \
|
const float *__restrict__ global_scale_ptr, \
|
||||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||||
const int32_t *__restrict__ expert_ids_ptr, \
|
const int32_t *__restrict__ expert_ids_ptr, \
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ __global__ void Marlin(
|
|||||||
// fp16 quantization scales. shape (k/groupsize, n)
|
// fp16 quantization scales. shape (k/groupsize, n)
|
||||||
const int4* __restrict__ scales_ptr,
|
const int4* __restrict__ scales_ptr,
|
||||||
// fp16 global scale (for nvfp4// only)
|
// float global scale (for nvfp4 only)
|
||||||
const uint16_t* __restrict__ global_scale_ptr,
|
const float* __restrict__ global_scale_ptr,
|
||||||
// 4bit packed zero-points of shape
|
// 4bit packed zero-points of shape
|
||||||
// (k/groupsize, n/pack_factor)
|
// (k/groupsize, n/pack_factor)
|
||||||
const int4* __restrict__ zp_ptr,
|
const int4* __restrict__ zp_ptr,
|
||||||
@@ -308,7 +308,14 @@ __global__ void Marlin(
|
|||||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
static constexpr auto num_bits =
|
||||||
|
vllm::ScalarType::from_id(b_type_id).size_bits();
|
||||||
|
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
|
||||||
|
// num_bits == 4
|
||||||
|
constexpr bool use_fp16_accum =
|
||||||
|
a_type_id == vllm::kFloat16.id() &&
|
||||||
|
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
|
||||||
|
!(group_blocks == -1 && num_bits == 4));
|
||||||
#else
|
#else
|
||||||
constexpr bool use_fp16_accum = false;
|
constexpr bool use_fp16_accum = false;
|
||||||
#endif
|
#endif
|
||||||
@@ -357,7 +364,7 @@ __global__ void Marlin(
|
|||||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||||
|
|
||||||
c_scalar_t2 global_scale;
|
float global_scale_f32 = 1.0f;
|
||||||
|
|
||||||
constexpr bool has_act_order = group_blocks == 0;
|
constexpr bool has_act_order = group_blocks == 0;
|
||||||
|
|
||||||
@@ -507,11 +514,12 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
if (mul_topk_weights) {
|
if (mul_topk_weights) {
|
||||||
idx = idx < prob_m_top_k ? idx : 0;
|
idx = idx < prob_m_top_k ? idx : 0;
|
||||||
c_scalar_t2 topk_weight_val =
|
float topk_weight_tmp = topk_weights_ptr[idx];
|
||||||
Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
|
|
||||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
topk_weight_val = __hmul2(topk_weight_val, global_scale);
|
topk_weight_tmp *= global_scale_f32;
|
||||||
}
|
}
|
||||||
|
c_scalar_t2 topk_weight_val =
|
||||||
|
Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp));
|
||||||
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
|
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -532,8 +540,7 @@ __global__ void Marlin(
|
|||||||
expert_id = expert_ids_ptr[block_id];
|
expert_id = expert_ids_ptr[block_id];
|
||||||
|
|
||||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
uint16_t val = global_scale_ptr[expert_id];
|
global_scale_f32 = global_scale_ptr[expert_id];
|
||||||
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
||||||
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
|
|||||||
// We first reorder in shared memory to guarantee the most efficient final
|
// We first reorder in shared memory to guarantee the most efficient final
|
||||||
// global write patterns
|
// global write patterns
|
||||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||||
|
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
|
if (!mul_topk_weights) {
|
||||||
|
c0 *= global_scale_f32;
|
||||||
|
c1 *= global_scale_f32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c_scalar_t2 res =
|
c_scalar_t2 res =
|
||||||
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
||||||
|
|
||||||
@@ -1800,11 +1814,6 @@ __global__ void Marlin(
|
|||||||
res = __hmul2(res, tmp_scale);
|
res = __hmul2(res, tmp_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
|
||||||
if (!mul_topk_weights) {
|
|
||||||
res = __hmul2(res, global_scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (has_bias && last) {
|
if (has_bias && last) {
|
||||||
c_scalar_t2 tmp_bias = b_bias[0];
|
c_scalar_t2 tmp_bias = b_bias[0];
|
||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
|
|||||||
@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
|||||||
const int4* bias_ptr = (const int4*)b_bias;
|
const int4* bias_ptr = (const int4*)b_bias;
|
||||||
const float* a_s_ptr = (const float*)a_s;
|
const float* a_s_ptr = (const float*)a_s;
|
||||||
const int4* b_s_ptr = (const int4*)b_s;
|
const int4* b_s_ptr = (const int4*)b_s;
|
||||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
const float* g_s_ptr = (const float*)g_s;
|
||||||
const int4* zp_ptr = (const int4*)zp;
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
const int* g_idx_ptr = (const int*)g_idx;
|
const int* g_idx_ptr = (const int*)g_idx;
|
||||||
const int* perm_ptr = (const int*)perm;
|
const int* perm_ptr = (const int*)perm;
|
||||||
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||||
"global_scale can only be used for nvfp4 format.");
|
"global_scale can only be used for nvfp4 format.");
|
||||||
} else {
|
} else {
|
||||||
global_scale = torch::empty({0}, options);
|
global_scale = torch::empty({0}, options_fp32);
|
||||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||||
"the global_scale parameter must be passed for nvfp4 format.");
|
"the global_scale parameter must be passed for nvfp4 format.");
|
||||||
}
|
}
|
||||||
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
|
|
||||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||||
"scalar type of a_scales must be float");
|
"scalar type of a_scales must be float");
|
||||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
|
||||||
"scalar type of global_scale must be the same with c");
|
"scalar type of global_scale must be float");
|
||||||
if (a_type.size_bits() == 16) {
|
if (a_type.size_bits() == 16) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
a.scalar_type() == c.scalar_type(),
|
a.scalar_type() == c.scalar_type(),
|
||||||
|
|||||||
@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
|
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
|
||||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
|
||||||
return fminf(mmax, fmaxf(v, mmin));
|
return fminf(mmax, fmaxf(v, mmin));
|
||||||
#else
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
|
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
const int4 *__restrict__ b_bias_ptr, \
|
const int4 *__restrict__ b_bias_ptr, \
|
||||||
const float *__restrict__ a_scales_ptr, \
|
const float *__restrict__ a_scales_ptr, \
|
||||||
const int4 *__restrict__ scales_ptr, \
|
const int4 *__restrict__ scales_ptr, \
|
||||||
const uint16_t *__restrict__ global_scale_ptr, \
|
const float *__restrict__ global_scale_ptr, \
|
||||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||||
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
|
|||||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||||
bool is_zp_float) {
|
bool is_zp_float) {
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
"marlin_gemm(..) requires CUDA_ARCH >= 7.5");
|
||||||
return torch::empty({1, 1});
|
return torch::empty({1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
|||||||
const int4* bias_ptr = (const int4*)b_bias;
|
const int4* bias_ptr = (const int4*)b_bias;
|
||||||
const float* a_s_ptr = (const float*)a_s;
|
const float* a_s_ptr = (const float*)a_s;
|
||||||
const int4* b_s_ptr = (const int4*)b_s;
|
const int4* b_s_ptr = (const int4*)b_s;
|
||||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
const float* g_s_ptr = (const float*)g_s;
|
||||||
|
|
||||||
const int4* zp_ptr = (const int4*)zp;
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
const int* g_idx_ptr = (const int*)g_idx;
|
const int* g_idx_ptr = (const int*)g_idx;
|
||||||
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
|
|||||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||||
"global_scale can only be used for nvfp4 format.");
|
"global_scale can only be used for nvfp4 format.");
|
||||||
} else {
|
} else {
|
||||||
global_scale = torch::empty({0}, options);
|
global_scale = torch::empty({0}, options_fp32);
|
||||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||||
"the global_scale parameter must be passed for nvfp4 format.");
|
"the global_scale parameter must be passed for nvfp4 format.");
|
||||||
}
|
}
|
||||||
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(
|
|||||||
|
|
||||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||||
"scalar type of a_scales must be float");
|
"scalar type of a_scales must be float");
|
||||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
|
||||||
"scalar type of global_scale must be the same with c");
|
"scalar type of global_scale must be float");
|
||||||
if (a_type.size_bits() == 16) {
|
if (a_type.size_bits() == 16) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
a.scalar_type() == c.scalar_type(),
|
a.scalar_type() == c.scalar_type(),
|
||||||
|
|||||||
@@ -251,8 +251,8 @@ __global__ void Marlin(
|
|||||||
const float* __restrict__ a_scales_ptr,
|
const float* __restrict__ a_scales_ptr,
|
||||||
// fp16 quantization scales. shape (k/groupsize, n)
|
// fp16 quantization scales. shape (k/groupsize, n)
|
||||||
const int4* __restrict__ scales_ptr,
|
const int4* __restrict__ scales_ptr,
|
||||||
// fp16 global scale (for nvfp4// only)
|
// float global scale (for nvfp4 only)
|
||||||
const uint16_t* __restrict__ global_scale_ptr,
|
const float* __restrict__ global_scale_ptr,
|
||||||
// 4bit packed zero-points of shape
|
// 4bit packed zero-points of shape
|
||||||
// (k/groupsize, n/pack_factor)
|
// (k/groupsize, n/pack_factor)
|
||||||
const int4* __restrict__ zp_ptr,
|
const int4* __restrict__ zp_ptr,
|
||||||
@@ -292,7 +292,13 @@ __global__ void Marlin(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
|
||||||
|
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
|
||||||
|
// num_bits == 4
|
||||||
|
constexpr bool use_fp16_accum =
|
||||||
|
a_type_id == vllm::kFloat16.id() &&
|
||||||
|
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
|
||||||
|
!(group_blocks == -1 && num_bits == 4));
|
||||||
#else
|
#else
|
||||||
constexpr bool use_fp16_accum = false;
|
constexpr bool use_fp16_accum = false;
|
||||||
#endif
|
#endif
|
||||||
@@ -342,11 +348,10 @@ __global__ void Marlin(
|
|||||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||||
|
|
||||||
c_scalar_t2 global_scale;
|
float global_scale_f32 = 1.0f;
|
||||||
|
|
||||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
uint16_t val = global_scale_ptr[0];
|
global_scale_f32 = global_scale_ptr[0];
|
||||||
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr bool has_act_order = group_blocks == 0;
|
constexpr bool has_act_order = group_blocks == 0;
|
||||||
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
|
|||||||
// We first reorder in shared memory to guarantee the most efficient final
|
// We first reorder in shared memory to guarantee the most efficient final
|
||||||
// global write patterns
|
// global write patterns
|
||||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||||
|
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||||
|
c0 *= global_scale_f32;
|
||||||
|
c1 *= global_scale_f32;
|
||||||
|
}
|
||||||
c_scalar_t2 res =
|
c_scalar_t2 res =
|
||||||
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
||||||
|
|
||||||
@@ -1659,10 +1668,6 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
res = __hmul2(res, tmp_scale);
|
res = __hmul2(res, tmp_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
|
||||||
res = __hmul2(res, global_scale);
|
|
||||||
}
|
|
||||||
if (has_bias && last) {
|
if (has_bias && last) {
|
||||||
c_scalar_t2 tmp_bias = b_bias[0];
|
c_scalar_t2 tmp_bias = b_bias[0];
|
||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
|
|||||||
@@ -27,10 +27,19 @@ def is_fp4_marlin_supported():
|
|||||||
return current_platform.has_device_capability(75)
|
return current_platform.has_device_capability(75)
|
||||||
|
|
||||||
|
|
||||||
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
def _nvfp4_compute_scale_factor(
|
||||||
|
marlin_scales: torch.Tensor,
|
||||||
|
a_dtype: torch.dtype | None = None,
|
||||||
|
) -> float:
|
||||||
"""Compute the power-of-2 scale_factor needed so that all non-zero
|
"""Compute the power-of-2 scale_factor needed so that all non-zero
|
||||||
values in marlin_scales * 2^7 are >= 2 after rescaling.
|
values in marlin_scales * 2^7 are >= 2 after rescaling.
|
||||||
Returns a Python float (power of 2, >= 1.0)."""
|
Returns a Python float (power of 2, >= 1.0)."""
|
||||||
|
|
||||||
|
# Since half has a smaller dynamic range compared to bfloat16,
|
||||||
|
# no rescaling is applied here when a_dtype is torch.half.
|
||||||
|
if a_dtype is not None and a_dtype == torch.half:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
ws_float = marlin_scales.float() * (2**7)
|
ws_float = marlin_scales.float() * (2**7)
|
||||||
nonzero_mask = ws_float > 0
|
nonzero_mask = ws_float > 0
|
||||||
if nonzero_mask.any():
|
if nonzero_mask.any():
|
||||||
@@ -44,6 +53,7 @@ def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
|||||||
def nvfp4_marlin_process_scales(
|
def nvfp4_marlin_process_scales(
|
||||||
marlin_scales: torch.Tensor,
|
marlin_scales: torch.Tensor,
|
||||||
scale_factor: float | None = None,
|
scale_factor: float | None = None,
|
||||||
|
a_dtype: torch.dtype | None = None,
|
||||||
) -> tuple[torch.Tensor, float]:
|
) -> tuple[torch.Tensor, float]:
|
||||||
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
|
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
|
||||||
|
|
||||||
@@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales(
|
|||||||
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
|
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
|
||||||
# The caller must compensate by dividing global_scale by scale_factor.
|
# The caller must compensate by dividing global_scale by scale_factor.
|
||||||
if scale_factor is None:
|
if scale_factor is None:
|
||||||
scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
|
scale_factor = _nvfp4_compute_scale_factor(marlin_scales, a_dtype)
|
||||||
if scale_factor > 1.0:
|
if scale_factor > 1.0:
|
||||||
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
|
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
|
||||||
|
|
||||||
@@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
|
|||||||
return marlin_scales
|
return marlin_scales
|
||||||
|
|
||||||
|
|
||||||
def nvfp4_marlin_process_global_scale(global_scale):
|
def nvfp4_marlin_process_global_scale(global_scale, a_dtype: torch.dtype | None = None):
|
||||||
assert global_scale.dtype in [torch.half, torch.bfloat16]
|
if a_dtype is None:
|
||||||
|
a_dtype = global_scale.dtype
|
||||||
|
assert a_dtype in [torch.half, torch.bfloat16]
|
||||||
fp4_exponent = 2
|
fp4_exponent = 2
|
||||||
if global_scale.dtype == torch.half:
|
if a_dtype == torch.half:
|
||||||
target_exponent = 5
|
target_exponent = 5
|
||||||
elif global_scale.dtype == torch.bfloat16:
|
elif a_dtype == torch.bfloat16:
|
||||||
target_exponent = 8
|
target_exponent = 8
|
||||||
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
|
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
|
||||||
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
|
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
|
||||||
@@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
|
weight_scale, scale_factor = nvfp4_marlin_process_scales(
|
||||||
|
weight_scale, a_dtype=param_dtype
|
||||||
|
)
|
||||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||||
|
|
||||||
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
weight_global_scale = layer.weight_global_scale.to(torch.float32)
|
||||||
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
weight_global_scale = nvfp4_marlin_process_global_scale(
|
||||||
|
weight_global_scale, param_dtype
|
||||||
|
)
|
||||||
weight_global_scale = weight_global_scale / scale_factor
|
weight_global_scale = weight_global_scale / scale_factor
|
||||||
layer.weight_global_scale = torch.nn.Parameter(
|
layer.weight_global_scale = torch.nn.Parameter(
|
||||||
weight_global_scale, requires_grad=False
|
weight_global_scale, requires_grad=False
|
||||||
@@ -339,7 +355,6 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
|||||||
scales: torch.Tensor, g_scales: torch.Tensor, name: str
|
scales: torch.Tensor, g_scales: torch.Tensor, name: str
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
scales = scales.to(param_dtype)
|
scales = scales.to(param_dtype)
|
||||||
g_scales = g_scales.to(param_dtype)
|
|
||||||
|
|
||||||
tensor_list = []
|
tensor_list = []
|
||||||
num_shards = 2 if is_act_and_mul else 1
|
num_shards = 2 if is_act_and_mul else 1
|
||||||
@@ -350,7 +365,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
|||||||
|
|
||||||
# All experts share one global_scale, so compute the max
|
# All experts share one global_scale, so compute the max
|
||||||
# scale_factor across all experts first, then apply uniformly.
|
# scale_factor across all experts first, then apply uniformly.
|
||||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
|
||||||
|
|
||||||
for i in range(E):
|
for i in range(E):
|
||||||
scale = scales[i].T
|
scale = scales[i].T
|
||||||
@@ -362,12 +377,12 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
|||||||
is_a_8bit=is_a_8bit,
|
is_a_8bit=is_a_8bit,
|
||||||
)
|
)
|
||||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||||
marlin_scales, scale_factor=combined_scale_factor
|
marlin_scales, scale_factor=combined_scale_factor, a_dtype=param_dtype
|
||||||
)
|
)
|
||||||
tensor_list.append(marlin_scales)
|
tensor_list.append(marlin_scales)
|
||||||
|
|
||||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||||
g_scales = nvfp4_marlin_process_global_scale(g_scales)
|
g_scales = nvfp4_marlin_process_global_scale(g_scales, param_dtype)
|
||||||
g_scales = g_scales / combined_scale_factor
|
g_scales = g_scales / combined_scale_factor
|
||||||
return scales, g_scales
|
return scales, g_scales
|
||||||
|
|
||||||
@@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
scales = scales.view(torch.float8_e8m0fnu)
|
scales = scales.view(torch.float8_e8m0fnu)
|
||||||
scales = scales.to(param_dtype)
|
scales = scales.to(param_dtype)
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
|
global_scale = getattr(layer, name + "_weight_scale_2")
|
||||||
|
|
||||||
tensor_list = []
|
tensor_list = []
|
||||||
if "w13" in name:
|
if "w13" in name:
|
||||||
@@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
# For NVFP4: compute unified scale_factor across all experts
|
# For NVFP4: compute unified scale_factor across all experts
|
||||||
combined_scale_factor = None
|
combined_scale_factor = None
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
|
||||||
|
|
||||||
for i in range(e):
|
for i in range(e):
|
||||||
scale = scales[i].T
|
scale = scales[i].T
|
||||||
@@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
)
|
)
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||||
marlin_scales, scale_factor=combined_scale_factor
|
marlin_scales,
|
||||||
|
scale_factor=combined_scale_factor,
|
||||||
|
a_dtype=param_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
marlin_scales = mxfp4_marlin_process_scales(
|
marlin_scales = mxfp4_marlin_process_scales(
|
||||||
@@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
|
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
assert combined_scale_factor is not None
|
assert combined_scale_factor is not None
|
||||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
global_scale = nvfp4_marlin_process_global_scale(global_scale, param_dtype)
|
||||||
global_scale = global_scale / combined_scale_factor
|
global_scale = global_scale / combined_scale_factor
|
||||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||||
@@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
|
|||||||
)
|
)
|
||||||
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
|
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
|
||||||
|
|
||||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
global_scale = nvfp4_marlin_process_global_scale(global_scale).to(torch.float32)
|
||||||
global_scale = global_scale / scale_factor
|
global_scale = global_scale / scale_factor
|
||||||
|
|
||||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||||
|
|||||||
Reference in New Issue
Block a user