diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h
index e5a3a0b9c..09ed1a470 100644
--- a/csrc/moe/marlin_moe_wna16/kernel.h
+++ b/csrc/moe/marlin_moe_wna16/kernel.h
@@ -13,7 +13,7 @@
       const int4 *__restrict__ b_bias_ptr,                              \
       const float *__restrict__ a_scales_ptr,                          \
       const int4 *__restrict__ scales_ptr,                             \
-      const uint16_t *__restrict__ global_scale_ptr,                   \
+      const float *__restrict__ global_scale_ptr,                      \
       const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,  \
       const int32_t *__restrict__ sorted_token_ids_ptr,                \
       const int32_t *__restrict__ expert_ids_ptr,                      \
diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h
index cddc42643..f5685b898 100644
--- a/csrc/moe/marlin_moe_wna16/marlin_template.h
+++ b/csrc/moe/marlin_moe_wna16/marlin_template.h
@@ -260,7 +260,7 @@ __global__ void Marlin(
     // fp16 quantization scales. shape (k/groupsize, n)
     const int4* __restrict__ scales_ptr,
     // fp16 global scale (for nvfp4 only)
-    const uint16_t* __restrict__ global_scale_ptr,
+    const float* __restrict__ global_scale_ptr,
     // 4bit packed zero-points of shape
     // (k/groupsize, n/pack_factor)
     const int4* __restrict__ zp_ptr,
@@ -308,7 +308,14 @@ __global__ void Marlin(
   constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-  constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
+  static constexpr auto num_bits =
+      vllm::ScalarType::from_id(b_type_id).size_bits();
+  // Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
+  // num_bits == 4
+  constexpr bool use_fp16_accum =
+      a_type_id == vllm::kFloat16.id() &&
+      (!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
+       !(group_blocks == -1 && num_bits == 4));
 #else
   constexpr bool use_fp16_accum = false;
 #endif
@@ -357,7 +364,7 @@ __global__ void Marlin(
       has_zp && !is_zp_float && !std::is_same<scalar_t, half>::value ||
       has_zp && !is_zp_float && !(b_type == vllm::kU8);
 
-  c_scalar_t2 global_scale;
+  float global_scale_f32 = 1.0f;
 
   constexpr bool has_act_order = group_blocks == 0;
@@ -507,11 +514,12 @@ __global__ void Marlin(
         if (mul_topk_weights) {
           idx = idx < prob_m_top_k ? idx : 0;
-          c_scalar_t2 topk_weight_val =
-              Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
+          float topk_weight_tmp = topk_weights_ptr[idx];
           if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-            topk_weight_val = __hmul2(topk_weight_val, global_scale);
+            topk_weight_tmp *= global_scale_f32;
           }
+          c_scalar_t2 topk_weight_val =
+              Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp));
           sh_block_topk_weights[threadIdx.x] = topk_weight_val;
         }
       }
@@ -532,8 +540,7 @@ __global__ void Marlin(
     expert_id = expert_ids_ptr[block_id];
 
     if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-      uint16_t val = global_scale_ptr[expert_id];
-      global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
+      global_scale_f32 = global_scale_ptr[expert_id];
    }

    B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
   // We first reorder in shared memory to guarantee the most efficient final
   // global write patterns
   auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
+    if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+      if (!mul_topk_weights) {
+        c0 *= global_scale_f32;
+        c1 *= global_scale_f32;
+      }
+    }
+
     c_scalar_t2 res =
         Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
@@ -1800,11 +1814,6 @@ __global__ void Marlin(
       res = __hmul2(res, tmp_scale);
     }
 
-    if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-      if (!mul_topk_weights) {
-        res = __hmul2(res, global_scale);
-      }
-    }
     if (has_bias && last) {
       c_scalar_t2 tmp_bias = b_bias[0];
       if constexpr (m_block_size_8) {
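The sm75 `use_fp16_accum` gate above is resolved at compile time through `constexpr`. For readability, here is a hedged Python restatement of the same predicate; the boolean parameters are illustrative stand-ins for the `a_type_id`/`b_type_id`/`s_type_id` template arguments:

```python
# Illustrative restatement of the new use_fp16_accum condition for
# __CUDA_ARCH__ == 750 (the kernel evaluates this at compile time).
def use_fp16_accum(a_is_fp16: bool, b_is_fe2m1f: bool, s_is_fe4m3fn: bool,
                   group_blocks: int, num_bits: int) -> bool:
    if not a_is_fp16:
        return False  # fp16 accumulation only applies to fp16 activations
    if b_is_fe2m1f and s_is_fe4m3fn:
        return False  # NVFP4 (FE2M1f weights + FE4M3fn scales) accumulates in fp32
    if group_blocks == -1 and num_bits == 4:
        return False  # channelwise (group_size == -1) 4-bit also opts out
    return True
```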
diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu
index e3f3b4175..60681ad93 100644
--- a/csrc/moe/marlin_moe_wna16/ops.cu
+++ b/csrc/moe/marlin_moe_wna16/ops.cu
@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   const int4* bias_ptr = (const int4*)b_bias;
   const float* a_s_ptr = (const float*)a_s;
   const int4* b_s_ptr = (const int4*)b_s;
-  const uint16_t* g_s_ptr = (const uint16_t*)g_s;
+  const float* g_s_ptr = (const float*)g_s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
   const int* perm_ptr = (const int*)perm;
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
     TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
                 "global_scale can only be used for nvfp4 format.");
   } else {
-    global_scale = torch::empty({0}, options);
+    global_scale = torch::empty({0}, options_fp32);
     TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
                 "the global_scale parameter must be passed for nvfp4 format.");
   }
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(
   TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
               "scalar type of a_scales must be float");
 
-  TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
-              "scalar type of global_scale must be the same with c");
+  TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
+              "scalar type of global_scale must be float");
   if (a_type.size_bits() == 16) {
     TORCH_CHECK(
         a.scalar_type() == c.scalar_type(),
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index c0153bb41..8cc645c33 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
 }
 
 __device__ __forceinline__ float clip(float v, float mmin, float mmax) {
-#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
   return fminf(mmax, fmaxf(v, mmin));
-#else
-#endif
 }
 
 __device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
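The `clip()` change above is a correctness fix rather than a cleanup: the removed guard compiled the body only when `__CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800`, and the empty `#else` branch let the function fall off the end without returning a value on older architectures. The intended semantics need no architecture guard; sketched in Python:

```python
# Plain clamp semantics of the fixed clip(): min/max work on every architecture.
def clip(v: float, mmin: float, mmax: float) -> float:
    return min(mmax, max(v, mmin))

assert clip(5.0, 0.0, 1.0) == 1.0   # clamped to the upper bound
assert clip(-5.0, 0.0, 1.0) == 0.0  # clamped to the lower bound
assert clip(0.5, 0.0, 1.0) == 0.5   # passed through unchanged
```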
diff --git a/csrc/quantization/marlin/kernel.h b/csrc/quantization/marlin/kernel.h
index b3b79c8ae..8c9cec88b 100644
--- a/csrc/quantization/marlin/kernel.h
+++ b/csrc/quantization/marlin/kernel.h
@@ -13,7 +13,7 @@
       const int4 *__restrict__ b_bias_ptr,                              \
       const float *__restrict__ a_scales_ptr,                          \
       const int4 *__restrict__ scales_ptr,                             \
-      const uint16_t *__restrict__ global_scale_ptr,                   \
+      const float *__restrict__ global_scale_ptr,                      \
       const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,  \
       int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
       bool has_bias, bool use_atomic_add, bool use_fp32_reduce,        \
diff --git a/csrc/quantization/marlin/marlin.cu b/csrc/quantization/marlin/marlin.cu
index 62826128c..fbdb619c2 100644
--- a/csrc/quantization/marlin/marlin.cu
+++ b/csrc/quantization/marlin/marlin.cu
@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
     int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
     bool is_zp_float) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
-                              "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
+                              "marlin_gemm(..) requires CUDA_ARCH >= 7.5");
   return torch::empty({1, 1});
 }
@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   const int4* bias_ptr = (const int4*)b_bias;
   const float* a_s_ptr = (const float*)a_s;
   const int4* b_s_ptr = (const int4*)b_s;
-  const uint16_t* g_s_ptr = (const uint16_t*)g_s;
+  const float* g_s_ptr = (const float*)g_s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
     TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
                 "global_scale can only be used for nvfp4 format.");
   } else {
-    global_scale = torch::empty({0}, options);
+    global_scale = torch::empty({0}, options_fp32);
     TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
                 "the global_scale parameter must be passed for nvfp4 format.");
   }
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(
   TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
               "scalar type of a_scales must be float");
 
-  TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
-              "scalar type of global_scale must be the same with c");
+  TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
+              "scalar type of global_scale must be float");
   if (a_type.size_bits() == 16) {
     TORCH_CHECK(
         a.scalar_type() == c.scalar_type(),
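Taken together, the host-side contract is now: `global_scale`, like `a_scales`, is always an fp32 tensor regardless of the output dtype, and the fallback stub's message reflects the Turing (sm75) minimum. A hedged caller-side sketch of these expectations; the helper name is illustrative, not part of vLLM:

```python
import torch

def check_marlin_nvfp4_inputs(a_scales: torch.Tensor,
                              global_scale: torch.Tensor) -> None:
    # Mirrors the TORCH_CHECKs above: both scale tensors must be fp32.
    assert a_scales.dtype == torch.float32, "scalar type of a_scales must be float"
    assert global_scale.dtype == torch.float32, (
        "scalar type of global_scale must be float")

# Requires a CUDA device; Marlin kernels need at least sm75 (Turing).
major, minor = torch.cuda.get_device_capability()
assert (major, minor) >= (7, 5), "marlin_gemm(..) requires CUDA_ARCH >= 7.5"
```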
diff --git a/csrc/quantization/marlin/marlin_template.h b/csrc/quantization/marlin/marlin_template.h
index c7b53696c..9e625b645 100644
--- a/csrc/quantization/marlin/marlin_template.h
+++ b/csrc/quantization/marlin/marlin_template.h
@@ -251,8 +251,8 @@ __global__ void Marlin(
     const float* __restrict__ a_scales_ptr,
     // fp16 quantization scales. shape (k/groupsize, n)
     const int4* __restrict__ scales_ptr,
-    // fp16 global scale (for nvfp4 only)
-    const uint16_t* __restrict__ global_scale_ptr,
+    // float global scale (for nvfp4 only)
+    const float* __restrict__ global_scale_ptr,
     // 4bit packed zero-points of shape
     // (k/groupsize, n/pack_factor)
     const int4* __restrict__ zp_ptr,
@@ -292,7 +292,13 @@ __global__ void Marlin(
 #endif
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-  constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
+  constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
+  // Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
+  // num_bits == 4
+  constexpr bool use_fp16_accum =
+      a_type_id == vllm::kFloat16.id() &&
+      (!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
+       !(group_blocks == -1 && num_bits == 4));
 #else
   constexpr bool use_fp16_accum = false;
 #endif
@@ -342,11 +348,10 @@ __global__ void Marlin(
       has_zp && !is_zp_float && !std::is_same<scalar_t, half>::value ||
       has_zp && !is_zp_float && !(b_type == vllm::kU8);
 
-  c_scalar_t2 global_scale;
+  float global_scale_f32 = 1.0f;
 
   if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-    uint16_t val = global_scale_ptr[0];
-    global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
+    global_scale_f32 = global_scale_ptr[0];
   }
 
   constexpr bool has_act_order = group_blocks == 0;
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
   // We first reorder in shared memory to guarantee the most efficient final
   // global write patterns
   auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
+    if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+      c0 *= global_scale_f32;
+      c1 *= global_scale_f32;
+    }
     c_scalar_t2 res =
         Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
@@ -1659,10 +1668,6 @@ __global__ void Marlin(
       }
       res = __hmul2(res, tmp_scale);
     }
-
-    if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-      res = __hmul2(res, global_scale);
-    }
     if (has_bias && last) {
       c_scalar_t2 tmp_bias = b_bias[0];
       if constexpr (m_block_size_8) {
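In both kernel templates, the NVFP4 global-scale multiply moves from a half2 `__hmul2` on the already-converted result into an fp32 multiply on the raw accumulators, so the output is rounded once instead of twice and the scale itself no longer passes through fp16. A small numeric illustration (not vLLM code):

```python
import torch

def write_old(c: float, gs: float, out_dtype=torch.half) -> torch.Tensor:
    # Old epilogue order: round the fp32 accumulator to the output dtype,
    # then multiply by a global scale stored in that same narrow dtype.
    res = torch.tensor(c).to(out_dtype)
    return res * torch.tensor(gs).to(out_dtype)

def write_new(c: float, gs: float, out_dtype=torch.half) -> torch.Tensor:
    # New epilogue order: multiply in fp32, convert to the output dtype once.
    return torch.tensor(c * gs).to(out_dtype)

# The results can differ by an ulp of the output dtype; the new order
# performs a single rounding step.
print(write_old(3.14159, 0.012345).item(), write_new(3.14159, 0.012345).item())
```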
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
index 9bc58d2f3..4fd484ede 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
@@ -27,10 +27,19 @@ def is_fp4_marlin_supported():
     return current_platform.has_device_capability(75)
 
 
-def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
+def _nvfp4_compute_scale_factor(
+    marlin_scales: torch.Tensor,
+    a_dtype: torch.dtype | None = None,
+) -> float:
     """Compute the power-of-2 scale_factor needed so that all non-zero values
     in marlin_scales * 2^7 are >= 2 after rescaling.
     Returns a Python float (power of 2, >= 1.0)."""
+
+    # Since half has a smaller dynamic range compared to bfloat16,
+    # no rescaling is applied here if the activation dtype is half.
+    if a_dtype is not None and a_dtype == torch.half:
+        return 1.0
+
     ws_float = marlin_scales.float() * (2**7)
     nonzero_mask = ws_float > 0
     if nonzero_mask.any():
@@ -44,6 +53,7 @@
 def nvfp4_marlin_process_scales(
     marlin_scales: torch.Tensor,
     scale_factor: float | None = None,
+    a_dtype: torch.dtype | None = None,
 ) -> tuple[torch.Tensor, float]:
     """Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
@@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales(
     # to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
     # The caller must compensate by dividing global_scale by scale_factor.
     if scale_factor is None:
-        scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
+        scale_factor = _nvfp4_compute_scale_factor(marlin_scales, a_dtype)
 
     if scale_factor > 1.0:
         marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
@@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
     return marlin_scales
 
 
-def nvfp4_marlin_process_global_scale(global_scale):
-    assert global_scale.dtype in [torch.half, torch.bfloat16]
+def nvfp4_marlin_process_global_scale(global_scale, a_dtype: torch.dtype | None = None):
+    if a_dtype is None:
+        a_dtype = global_scale.dtype
+    assert a_dtype in [torch.half, torch.bfloat16]
     fp4_exponent = 2
-    if global_scale.dtype == torch.half:
+    if a_dtype == torch.half:
         target_exponent = 5
-    elif global_scale.dtype == torch.bfloat16:
+    elif a_dtype == torch.bfloat16:
         target_exponent = 8
     # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
     # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
@@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin(
     )
 
     if is_nvfp4:
-        weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
+        weight_scale, scale_factor = nvfp4_marlin_process_scales(
+            weight_scale, a_dtype=param_dtype
+        )
         layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
 
-        weight_global_scale = layer.weight_global_scale.to(param_dtype)
-        weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
+        weight_global_scale = layer.weight_global_scale.to(torch.float32)
+        weight_global_scale = nvfp4_marlin_process_global_scale(
+            weight_global_scale, param_dtype
+        )
         weight_global_scale = weight_global_scale / scale_factor
         layer.weight_global_scale = torch.nn.Parameter(
             weight_global_scale, requires_grad=False
@@ -339,7 +355,6 @@ def prepare_nvfp4_moe_layer_for_marlin(
         scales: torch.Tensor, g_scales: torch.Tensor, name: str
     ) -> tuple[torch.Tensor, torch.Tensor]:
         scales = scales.to(param_dtype)
-        g_scales = g_scales.to(param_dtype)
 
         tensor_list = []
         num_shards = 2 if is_act_and_mul else 1
@@ -350,7 +365,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
         # All experts share one global_scale, so compute the max
         # scale_factor across all experts first, then apply uniformly.
-        combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+        combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
 
         for i in range(E):
             scale = scales[i].T
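`_nvfp4_compute_scale_factor`'s body is only partially visible in the hunks above; based on its docstring and the new early-out, a hedged standalone sketch of the computation (not the exact vLLM implementation) would be:

```python
import torch

def compute_scale_factor(marlin_scales: torch.Tensor,
                         a_dtype: torch.dtype | None = None) -> float:
    # Sketch per the docstring: smallest power of 2 (>= 1.0) such that every
    # non-zero value of marlin_scales * 2**7 is >= 2 after rescaling.
    if a_dtype is not None and a_dtype == torch.half:
        return 1.0  # new behavior: skip rescaling for fp16 activations
    ws = marlin_scales.float() * 2**7
    nz = ws[ws > 0]
    if nz.numel() == 0:
        return 1.0
    exp = (2.0 / nz.min()).log2().ceil().clamp(min=0)
    return float(2.0**exp)
```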
@@ -362,12 +377,12 @@ def prepare_nvfp4_moe_layer_for_marlin(
                 is_a_8bit=is_a_8bit,
             )
             marlin_scales, _ = nvfp4_marlin_process_scales(
-                marlin_scales, scale_factor=combined_scale_factor
+                marlin_scales, scale_factor=combined_scale_factor, a_dtype=param_dtype
             )
             tensor_list.append(marlin_scales)
 
         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
-        g_scales = nvfp4_marlin_process_global_scale(g_scales)
+        g_scales = nvfp4_marlin_process_global_scale(g_scales, param_dtype)
         g_scales = g_scales / combined_scale_factor
         return scales, g_scales
@@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin(
         scales = scales.view(torch.float8_e8m0fnu)
         scales = scales.to(param_dtype)
     if is_nvfp4:
-        global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
+        global_scale = getattr(layer, name + "_weight_scale_2")
 
     tensor_list = []
     if "w13" in name:
@@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin(
     # For NVFP4: compute unified scale_factor across all experts
     combined_scale_factor = None
     if is_nvfp4:
-        combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+        combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype)
 
     for i in range(e):
         scale = scales[i].T
@@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin(
         )
         if is_nvfp4:
             marlin_scales, _ = nvfp4_marlin_process_scales(
-                marlin_scales, scale_factor=combined_scale_factor
+                marlin_scales,
+                scale_factor=combined_scale_factor,
+                a_dtype=param_dtype,
             )
         else:
             marlin_scales = mxfp4_marlin_process_scales(
@@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin(
 
     if is_nvfp4:
         assert combined_scale_factor is not None
-        global_scale = nvfp4_marlin_process_global_scale(global_scale)
+        global_scale = nvfp4_marlin_process_global_scale(global_scale, param_dtype)
         global_scale = global_scale / combined_scale_factor
         global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
         setattr(layer, name + "_weight_scale_2", global_scale)
@@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
     )
 
     marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
-    global_scale = nvfp4_marlin_process_global_scale(global_scale)
+    global_scale = nvfp4_marlin_process_global_scale(global_scale).to(torch.float32)
     global_scale = global_scale / scale_factor
 
     return weight_ref.T, marlin_qweight, marlin_scales, global_scale
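A closing note on the numerics this series relies on: the processed `global_scale` now stays in fp32 from the Python layer down to the kernel, and dividing it by a power-of-2 `scale_factor` only shifts the exponent, so nothing is lost before the single fp32 multiply in the `write()` epilogue. An illustrative check (the values are made up):

```python
import torch

global_scale = torch.tensor([0.0378], dtype=torch.float32)  # hypothetical value
scale_factor = 4.0  # power of two produced by the weight-scale rescaling

# Division by a power of two is exact in fp32 (exponent shift only),
# so the adjustment can be undone bit-for-bit.
adjusted = global_scale / scale_factor
assert adjusted.item() * scale_factor == global_scale.item()
```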