[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < 2; ++k_idx) {
|
||||
FType low16 =
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType high16 =
|
||||
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
FType low16 = MarlinScalarType2<FType>::float2num(
|
||||
C_frag[m_idx][n_idx][k_idx * 2]);
|
||||
FType high16 = MarlinScalarType2<FType>::float2num(
|
||||
C_frag[m_idx][n_idx][k_idx * 2 + 1]);
|
||||
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
|
||||
(reinterpret_cast<uint32_t&>(high16) << 16);
|
||||
int sts_offset =
|
||||
|
||||
Reference in New Issue
Block a user