[Bugfix]fix output Nan/Inf in marlin if dtype=float16 (#33972)

Signed-off-by: IriKa Qiu <qiujie.jq@gmail.com>
This commit is contained in:
IriKa
2026-03-28 07:36:08 +08:00
committed by GitHub
parent b69bf2f0b1
commit 148a5c1226
8 changed files with 83 additions and 55 deletions

View File

@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const float *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \

View File

@@ -260,7 +260,7 @@ __global__ void Marlin(
// fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ scales_ptr,
// fp16 global scale (for nvfp4// only)
const uint16_t* __restrict__ global_scale_ptr,
const float* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
@@ -308,7 +308,14 @@ __global__ void Marlin(
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
static constexpr auto num_bits =
vllm::ScalarType::from_id(b_type_id).size_bits();
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
// num_bits == 4
constexpr bool use_fp16_accum =
a_type_id == vllm::kFloat16.id() &&
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
!(group_blocks == -1 && num_bits == 4));
#else
constexpr bool use_fp16_accum = false;
#endif
@@ -357,7 +364,7 @@ __global__ void Marlin(
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
c_scalar_t2 global_scale;
float global_scale_f32 = 1.0f;
constexpr bool has_act_order = group_blocks == 0;
@@ -507,11 +514,12 @@ __global__ void Marlin(
if (mul_topk_weights) {
idx = idx < prob_m_top_k ? idx : 0;
c_scalar_t2 topk_weight_val =
Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
float topk_weight_tmp = topk_weights_ptr[idx];
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
topk_weight_val = __hmul2(topk_weight_val, global_scale);
topk_weight_tmp *= global_scale_f32;
}
c_scalar_t2 topk_weight_val =
Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp));
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
}
}
@@ -532,8 +540,7 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[expert_id];
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
global_scale_f32 = global_scale_ptr[expert_id];
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) {
c0 *= global_scale_f32;
c1 *= global_scale_f32;
}
}
c_scalar_t2 res =
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
@@ -1800,11 +1814,6 @@ __global__ void Marlin(
res = __hmul2(res, tmp_scale);
}
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if (has_bias && last) {
c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {

View File

@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const int4* bias_ptr = (const int4*)b_bias;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const float* g_s_ptr = (const float*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
global_scale = torch::empty({0}, options_fp32);
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
"scalar type of global_scale must be the same with c");
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
"scalar type of global_scale must be float");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),

View File

@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
}
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return fminf(mmax, fmaxf(v, mmin));
#else
#endif
}
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,

View File

@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const float *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \

View File

@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
"marlin_gemm(..) requires CUDA_ARCH >= 7.5");
return torch::empty({1, 1});
}
@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const int4* bias_ptr = (const int4*)b_bias;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const float* g_s_ptr = (const float*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
global_scale = torch::empty({0}, options_fp32);
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
"scalar type of global_scale must be the same with c");
TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
"scalar type of global_scale must be float");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),

View File

@@ -251,8 +251,8 @@ __global__ void Marlin(
const float* __restrict__ a_scales_ptr,
// fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ scales_ptr,
// fp16 global scale (for nvfp4// only)
const uint16_t* __restrict__ global_scale_ptr,
// float global scale (for nvfp4// only)
const float* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
@@ -292,7 +292,13 @@ __global__ void Marlin(
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
// num_bits == 4
constexpr bool use_fp16_accum =
a_type_id == vllm::kFloat16.id() &&
(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
!(group_blocks == -1 && num_bits == 4));
#else
constexpr bool use_fp16_accum = false;
#endif
@@ -342,11 +348,10 @@ __global__ void Marlin(
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
c_scalar_t2 global_scale;
float global_scale_f32 = 1.0f;
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[0];
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
global_scale_f32 = global_scale_ptr[0];
}
constexpr bool has_act_order = group_blocks == 0;
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
c0 *= global_scale_f32;
c1 *= global_scale_f32;
}
c_scalar_t2 res =
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
@@ -1659,10 +1668,6 @@ __global__ void Marlin(
}
res = __hmul2(res, tmp_scale);
}
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale);
}
if (has_bias && last) {
c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {

View File

@@ -27,10 +27,19 @@ def is_fp4_marlin_supported():
return current_platform.has_device_capability(75)
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
def _nvfp4_compute_scale_factor(
marlin_scales: torch.Tensor,
a_dtype: torch.dtype | None = None,
) -> float:
"""Compute the power-of-2 scale_factor needed so that all non-zero
values in marlin_scales * 2^7 are >= 2 after rescaling.
Returns a Python float (power of 2, >= 1.0)."""
# Since half has a smaller dynamic range compared to bfloat16,
# no rescaling is applied here if active dtype is half.
if a_dtype is not None and a_dtype == torch.half:
return 1.0
ws_float = marlin_scales.float() * (2**7)
nonzero_mask = ws_float > 0
if nonzero_mask.any():
@@ -44,6 +53,7 @@ def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
def nvfp4_marlin_process_scales(
marlin_scales: torch.Tensor,
scale_factor: float | None = None,
a_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, float]:
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
@@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales(
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
# The caller must compensate by dividing global_scale by scale_factor.
if scale_factor is None:
scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
scale_factor = _nvfp4_compute_scale_factor(marlin_scales, a_dtype)
if scale_factor > 1.0:
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
@@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
return marlin_scales
def nvfp4_marlin_process_global_scale(global_scale):
assert global_scale.dtype in [torch.half, torch.bfloat16]
def nvfp4_marlin_process_global_scale(global_scale, a_dtype: torch.dtype | None = None):
if a_dtype is None:
a_dtype = global_scale.dtype
assert a_dtype in [torch.half, torch.bfloat16]
fp4_exponent = 2
if global_scale.dtype == torch.half:
if a_dtype == torch.half:
target_exponent = 5
elif global_scale.dtype == torch.bfloat16:
elif a_dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
@@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin(
)
if is_nvfp4:
weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
weight_scale, scale_factor = nvfp4_marlin_process_scales(
weight_scale, a_dtype=param_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
weight_global_scale = layer.weight_global_scale.to(param_dtype)
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
weight_global_scale = layer.weight_global_scale.to(torch.float32)
weight_global_scale = nvfp4_marlin_process_global_scale(
weight_global_scale, param_dtype
)
weight_global_scale = weight_global_scale / scale_factor
layer.weight_global_scale = torch.nn.Parameter(
weight_global_scale, requires_grad=False
@@ -339,7 +355,6 @@ def prepare_nvfp4_moe_layer_for_marlin(
scales: torch.Tensor, g_scales: torch.Tensor, name: str
) -> tuple[torch.Tensor, torch.Tensor]:
scales = scales.to(param_dtype)
g_scales = g_scales.to(param_dtype)
tensor_list = []
num_shards = 2 if is_act_and_mul else 1
@@ -350,7 +365,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
# All experts share one global_scale, so compute the max
# scale_factor across all experts first, then apply uniformly.
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
for i in range(E):
scale = scales[i].T
@@ -362,12 +377,12 @@ def prepare_nvfp4_moe_layer_for_marlin(
is_a_8bit=is_a_8bit,
)
marlin_scales, _ = nvfp4_marlin_process_scales(
marlin_scales, scale_factor=combined_scale_factor
marlin_scales, scale_factor=combined_scale_factor, a_dtype=param_dtype
)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
g_scales = nvfp4_marlin_process_global_scale(g_scales)
g_scales = nvfp4_marlin_process_global_scale(g_scales, param_dtype)
g_scales = g_scales / combined_scale_factor
return scales, g_scales
@@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin(
scales = scales.view(torch.float8_e8m0fnu)
scales = scales.to(param_dtype)
if is_nvfp4:
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
global_scale = getattr(layer, name + "_weight_scale_2")
tensor_list = []
if "w13" in name:
@@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin(
# For NVFP4: compute unified scale_factor across all experts
combined_scale_factor = None
if is_nvfp4:
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
for i in range(e):
scale = scales[i].T
@@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin(
)
if is_nvfp4:
marlin_scales, _ = nvfp4_marlin_process_scales(
marlin_scales, scale_factor=combined_scale_factor
marlin_scales,
scale_factor=combined_scale_factor,
a_dtype=param_dtype,
)
else:
marlin_scales = mxfp4_marlin_process_scales(
@@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin(
if is_nvfp4:
assert combined_scale_factor is not None
global_scale = nvfp4_marlin_process_global_scale(global_scale)
global_scale = nvfp4_marlin_process_global_scale(global_scale, param_dtype)
global_scale = global_scale / combined_scale_factor
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
@@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
)
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
global_scale = nvfp4_marlin_process_global_scale(global_scale)
global_scale = nvfp4_marlin_process_global_scale(global_scale).to(torch.float32)
global_scale = global_scale / scale_factor
return weight_ref.T, marlin_qweight, marlin_scales, global_scale