[Refactor][Kernel] Add global helper to deduplicate vectorized memory ops (#35105)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e3691988d0
commit
a201ad72d8
@@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
Type const* __restrict__ in,
|
||||
float const* __restrict__ SFScale,
|
||||
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
|
||||
using PackedVec = vllm::PackedVec<Type>;
|
||||
using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
|
||||
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
||||
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
||||
if constexpr (CVT_FP4_PACK16) {
|
||||
ld256_or_zero_cg_u32<Type>(
|
||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||
valid);
|
||||
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
|
||||
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||
valid);
|
||||
} else {
|
||||
ld128_or_zero_cg_u32<Type>(
|
||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||
valid);
|
||||
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
|
||||
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||
valid);
|
||||
}
|
||||
|
||||
auto sf_out =
|
||||
@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
float const* __restrict__ SFScale,
|
||||
uint32_t* __restrict__ out,
|
||||
uint32_t* __restrict__ SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
|
||||
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
@@ -139,13 +139,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
// If we are outside valid rows OR outside valid columns -> Use Zeros
|
||||
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
|
||||
if constexpr (CVT_FP4_PACK16) {
|
||||
ld256_or_zero_cg_u32<Type>(
|
||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||
valid);
|
||||
ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
|
||||
&reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
|
||||
valid);
|
||||
} else {
|
||||
ld128_or_zero_cg_u32<Type>(
|
||||
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||
valid);
|
||||
ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
|
||||
&reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
|
||||
valid);
|
||||
}
|
||||
|
||||
auto sf_out =
|
||||
|
||||
Reference in New Issue
Block a user