[Bugfix]fix output Nan/Inf in marlin if dtype=float16 (#33972)
Signed-off-by: IriKa Qiu <qiujie.jq@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const float *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
|
||||
@@ -260,7 +260,7 @@ __global__ void Marlin(
|
||||
// fp16 quantization scales. shape (k/groupsize, n)
|
||||
const int4* __restrict__ scales_ptr,
|
||||
// fp16 global scale (for nvfp4// only)
|
||||
const uint16_t* __restrict__ global_scale_ptr,
|
||||
const float* __restrict__ global_scale_ptr,
|
||||
// 4bit packed zero-points of shape
|
||||
// (k/groupsize, n/pack_factor)
|
||||
const int4* __restrict__ zp_ptr,
|
||||
@@ -308,7 +308,14 @@ __global__ void Marlin(
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
static constexpr auto num_bits =
|
||||
vllm::ScalarType::from_id(b_type_id).size_bits();
|
||||
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
|
||||
// num_bits == 4
|
||||
constexpr bool use_fp16_accum =
|
||||
a_type_id == vllm::kFloat16.id() &&
|
||||
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
|
||||
!(group_blocks == -1 && num_bits == 4));
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
@@ -357,7 +364,7 @@ __global__ void Marlin(
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
|
||||
c_scalar_t2 global_scale;
|
||||
float global_scale_f32 = 1.0f;
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
|
||||
@@ -507,11 +514,12 @@ __global__ void Marlin(
|
||||
|
||||
if (mul_topk_weights) {
|
||||
idx = idx < prob_m_top_k ? idx : 0;
|
||||
c_scalar_t2 topk_weight_val =
|
||||
Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
|
||||
float topk_weight_tmp = topk_weights_ptr[idx];
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
topk_weight_val = __hmul2(topk_weight_val, global_scale);
|
||||
topk_weight_tmp *= global_scale_f32;
|
||||
}
|
||||
c_scalar_t2 topk_weight_val =
|
||||
Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp));
|
||||
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
|
||||
}
|
||||
}
|
||||
@@ -532,8 +540,7 @@ __global__ void Marlin(
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
uint16_t val = global_scale_ptr[expert_id];
|
||||
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
|
||||
global_scale_f32 = global_scale_ptr[expert_id];
|
||||
}
|
||||
|
||||
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
||||
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if (!mul_topk_weights) {
|
||||
c0 *= global_scale_f32;
|
||||
c1 *= global_scale_f32;
|
||||
}
|
||||
}
|
||||
|
||||
c_scalar_t2 res =
|
||||
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
||||
|
||||
@@ -1800,11 +1814,6 @@ __global__ void Marlin(
|
||||
res = __hmul2(res, tmp_scale);
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if (!mul_topk_weights) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
}
|
||||
if (has_bias && last) {
|
||||
c_scalar_t2 tmp_bias = b_bias[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
|
||||
@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const float* a_s_ptr = (const float*)a_s;
|
||||
const int4* b_s_ptr = (const int4*)b_s;
|
||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
||||
const float* g_s_ptr = (const float*)g_s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
global_scale = torch::empty({0}, options_fp32);
|
||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of global_scale must be float");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
|
||||
@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
return fminf(mmax, fmaxf(v, mmin));
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const float *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||
|
||||
@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 7.5");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const float* a_s_ptr = (const float*)a_s;
|
||||
const int4* b_s_ptr = (const int4*)b_s;
|
||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
||||
const float* g_s_ptr = (const float*)g_s;
|
||||
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
|
||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
global_scale = torch::empty({0}, options_fp32);
|
||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(
|
||||
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of global_scale must be float");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
|
||||
@@ -251,8 +251,8 @@ __global__ void Marlin(
|
||||
const float* __restrict__ a_scales_ptr,
|
||||
// fp16 quantization scales. shape (k/groupsize, n)
|
||||
const int4* __restrict__ scales_ptr,
|
||||
// fp16 global scale (for nvfp4// only)
|
||||
const uint16_t* __restrict__ global_scale_ptr,
|
||||
// float global scale (for nvfp4// only)
|
||||
const float* __restrict__ global_scale_ptr,
|
||||
// 4bit packed zero-points of shape
|
||||
// (k/groupsize, n/pack_factor)
|
||||
const int4* __restrict__ zp_ptr,
|
||||
@@ -292,7 +292,13 @@ __global__ void Marlin(
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
|
||||
constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
|
||||
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
|
||||
// num_bits == 4
|
||||
constexpr bool use_fp16_accum =
|
||||
a_type_id == vllm::kFloat16.id() &&
|
||||
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
|
||||
!(group_blocks == -1 && num_bits == 4));
|
||||
#else
|
||||
constexpr bool use_fp16_accum = false;
|
||||
#endif
|
||||
@@ -342,11 +348,10 @@ __global__ void Marlin(
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
|
||||
c_scalar_t2 global_scale;
|
||||
float global_scale_f32 = 1.0f;
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
uint16_t val = global_scale_ptr[0];
|
||||
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
|
||||
global_scale_f32 = global_scale_ptr[0];
|
||||
}
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
c0 *= global_scale_f32;
|
||||
c1 *= global_scale_f32;
|
||||
}
|
||||
c_scalar_t2 res =
|
||||
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
|
||||
|
||||
@@ -1659,10 +1668,6 @@ __global__ void Marlin(
|
||||
}
|
||||
res = __hmul2(res, tmp_scale);
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
if (has_bias && last) {
|
||||
c_scalar_t2 tmp_bias = b_bias[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
|
||||
@@ -27,10 +27,19 @@ def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(75)
|
||||
|
||||
|
||||
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
||||
def _nvfp4_compute_scale_factor(
|
||||
marlin_scales: torch.Tensor,
|
||||
a_dtype: torch.dtype | None = None,
|
||||
) -> float:
|
||||
"""Compute the power-of-2 scale_factor needed so that all non-zero
|
||||
values in marlin_scales * 2^7 are >= 2 after rescaling.
|
||||
Returns a Python float (power of 2, >= 1.0)."""
|
||||
|
||||
# Since half has a smaller dynamic range compared to bfloat16,
|
||||
# no rescaling is applied here if active dtype is half.
|
||||
if a_dtype is not None and a_dtype == torch.half:
|
||||
return 1.0
|
||||
|
||||
ws_float = marlin_scales.float() * (2**7)
|
||||
nonzero_mask = ws_float > 0
|
||||
if nonzero_mask.any():
|
||||
@@ -44,6 +53,7 @@ def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
||||
def nvfp4_marlin_process_scales(
|
||||
marlin_scales: torch.Tensor,
|
||||
scale_factor: float | None = None,
|
||||
a_dtype: torch.dtype | None = None,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
|
||||
|
||||
@@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales(
|
||||
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
|
||||
# The caller must compensate by dividing global_scale by scale_factor.
|
||||
if scale_factor is None:
|
||||
scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
|
||||
scale_factor = _nvfp4_compute_scale_factor(marlin_scales, a_dtype)
|
||||
if scale_factor > 1.0:
|
||||
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
|
||||
|
||||
@@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
|
||||
return marlin_scales
|
||||
|
||||
|
||||
def nvfp4_marlin_process_global_scale(global_scale):
|
||||
assert global_scale.dtype in [torch.half, torch.bfloat16]
|
||||
def nvfp4_marlin_process_global_scale(global_scale, a_dtype: torch.dtype | None = None):
|
||||
if a_dtype is None:
|
||||
a_dtype = global_scale.dtype
|
||||
assert a_dtype in [torch.half, torch.bfloat16]
|
||||
fp4_exponent = 2
|
||||
if global_scale.dtype == torch.half:
|
||||
if a_dtype == torch.half:
|
||||
target_exponent = 5
|
||||
elif global_scale.dtype == torch.bfloat16:
|
||||
elif a_dtype == torch.bfloat16:
|
||||
target_exponent = 8
|
||||
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
|
||||
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
|
||||
@@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin(
|
||||
)
|
||||
|
||||
if is_nvfp4:
|
||||
weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
|
||||
weight_scale, scale_factor = nvfp4_marlin_process_scales(
|
||||
weight_scale, a_dtype=param_dtype
|
||||
)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
||||
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
||||
weight_global_scale = layer.weight_global_scale.to(torch.float32)
|
||||
weight_global_scale = nvfp4_marlin_process_global_scale(
|
||||
weight_global_scale, param_dtype
|
||||
)
|
||||
weight_global_scale = weight_global_scale / scale_factor
|
||||
layer.weight_global_scale = torch.nn.Parameter(
|
||||
weight_global_scale, requires_grad=False
|
||||
@@ -339,7 +355,6 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
scales: torch.Tensor, g_scales: torch.Tensor, name: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
scales = scales.to(param_dtype)
|
||||
g_scales = g_scales.to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
num_shards = 2 if is_act_and_mul else 1
|
||||
@@ -350,7 +365,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
|
||||
# All experts share one global_scale, so compute the max
|
||||
# scale_factor across all experts first, then apply uniformly.
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
|
||||
|
||||
for i in range(E):
|
||||
scale = scales[i].T
|
||||
@@ -362,12 +377,12 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||
marlin_scales, scale_factor=combined_scale_factor
|
||||
marlin_scales, scale_factor=combined_scale_factor, a_dtype=param_dtype
|
||||
)
|
||||
tensor_list.append(marlin_scales)
|
||||
|
||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
g_scales = nvfp4_marlin_process_global_scale(g_scales)
|
||||
g_scales = nvfp4_marlin_process_global_scale(g_scales, param_dtype)
|
||||
g_scales = g_scales / combined_scale_factor
|
||||
return scales, g_scales
|
||||
|
||||
@@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
scales = scales.view(torch.float8_e8m0fnu)
|
||||
scales = scales.to(param_dtype)
|
||||
if is_nvfp4:
|
||||
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
|
||||
global_scale = getattr(layer, name + "_weight_scale_2")
|
||||
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
@@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
# For NVFP4: compute unified scale_factor across all experts
|
||||
combined_scale_factor = None
|
||||
if is_nvfp4:
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
|
||||
|
||||
for i in range(e):
|
||||
scale = scales[i].T
|
||||
@@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
)
|
||||
if is_nvfp4:
|
||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||
marlin_scales, scale_factor=combined_scale_factor
|
||||
marlin_scales,
|
||||
scale_factor=combined_scale_factor,
|
||||
a_dtype=param_dtype,
|
||||
)
|
||||
else:
|
||||
marlin_scales = mxfp4_marlin_process_scales(
|
||||
@@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
|
||||
if is_nvfp4:
|
||||
assert combined_scale_factor is not None
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale, param_dtype)
|
||||
global_scale = global_scale / combined_scale_factor
|
||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
@@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
|
||||
)
|
||||
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale).to(torch.float32)
|
||||
global_scale = global_scale / scale_factor
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
|
||||
Reference in New Issue
Block a user