[NVIDIA] Blackwell Family (#24673)
Signed-off-by: Johnny <johnnynuca14@gmail.com> Signed-off-by: johnnynunez <johnnynuca14@gmail.com> Signed-off-by: Johnny <johnnync13@gmail.com> Signed-off-by: Salvatore Cena <cena@cenas.it> Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Co-authored-by: Salvatore Cena <cena@cenas.it>
This commit is contained in:
@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
|
||||
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
|
||||
Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
// TMA epilogue isn't compatible with Swap A/B
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
@@ -133,4 +133,4 @@ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
Reference in New Issue
Block a user