[Bugfix] Use latency MOE backend as default for Flashinfer and other misc fixes (#27439)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -31,6 +31,13 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename Int>
|
||||
__host__ __device__ inline Int round_up(Int x, Int y) {
|
||||
static_assert(std::is_integral_v<Int>,
|
||||
"round_up argument must be integral type");
|
||||
return (x + y - 1) / y * y;
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
int sf_m = round_up<int>(numRows, 128);
|
||||
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
|
||||
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
|
||||
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
|
||||
// Each thread writes 4 uint32_t elements.
|
||||
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
|
||||
col += blockDim.x * 4) {
|
||||
SFout[row * sf_n_int + col] = 0x00;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
|
||||
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
|
||||
|
||||
// Input tensor row/col loops.
|
||||
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
||||
@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
rowIdx, colIdx, numCols, SFout);
|
||||
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user