[Compile] Fix Compile Warning for w4a8_mm_entry.cu (#23660)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -11,6 +11,7 @@
|
|||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "cute/tensor.hpp"
|
#include "cute/tensor.hpp"
|
||||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||||
@@ -169,6 +170,11 @@ struct W4A8GemmKernel {
|
|||||||
int k = A.size(1);
|
int k = A.size(1);
|
||||||
int n = B.size(1);
|
int n = B.size(1);
|
||||||
|
|
||||||
|
// safely cast group_size to int
|
||||||
|
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
||||||
|
"group_size out of supported range for int: ", group_size);
|
||||||
|
int const group_size_int = static_cast<int>(group_size);
|
||||||
|
|
||||||
// Allocate output
|
// Allocate output
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||||
auto device = A.device();
|
auto device = A.device();
|
||||||
@@ -192,7 +198,7 @@ struct W4A8GemmKernel {
|
|||||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||||
|
|
||||||
// strides
|
// strides
|
||||||
int const scale_k = cutlass::ceil_div(k, group_size);
|
int const scale_k = cutlass::ceil_div(k, group_size_int);
|
||||||
StrideA stride_A =
|
StrideA stride_A =
|
||||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||||
// Reverse stride here due to swap and transpose
|
// Reverse stride here due to swap and transpose
|
||||||
@@ -211,8 +217,8 @@ struct W4A8GemmKernel {
|
|||||||
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
||||||
|
|
||||||
MainloopArguments mainloop_arguments{
|
MainloopArguments mainloop_arguments{
|
||||||
B_ptr, layout_B_reordered, A_ptr, stride_A,
|
B_ptr, layout_B_reordered, A_ptr, stride_A,
|
||||||
S_ptr, stride_S, group_size};
|
S_ptr, stride_S, group_size_int};
|
||||||
|
|
||||||
EpilogueArguments epilogue_arguments{
|
EpilogueArguments epilogue_arguments{
|
||||||
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
|
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
|
||||||
|
|||||||
Reference in New Issue
Block a user