[Refactor] Fix Compile Warning #1444-D (#21208)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
#include <cuda/std/functional>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/util_type.cuh>
|
||||
@@ -62,7 +63,7 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
cuda::std::plus<float> sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
|
||||
Reference in New Issue
Block a user