KV-1: Fix shared memory corruption in block_reduce
block_reduce_sum/max write n_warps floats to smem[0..n_warps-1], but we passed the address of a single float (&s_amax). With 128 threads (4 warps), each call therefore wrote 4 floats starting at &s_amax, corrupting the adjacent shared variables (s_inv_rms, s_vals). Fix: pass a properly sized s_scratch[8] array instead (slots 0-3 for the sum reduction, slots 4-7 for the max reduction).
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
/**
|
||||
* FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels.
|
||||
* KV-1/KV-2: Single kernel launch. FP32 -> E2M1 + E4M3 + FP32 gsa.
|
||||
* Quantize: each thread independently handles 16-element blocks.
|
||||
* No warp-shuffle group reduction (previous version had a cross-group bug).
|
||||
*
|
||||
* FIX: block_reduce_sum/max need n_warps shared memory slots, not a single
|
||||
* float. Previous version passed &s_amax (1 float) but the functions write
|
||||
* smem[0..n_warps-1], corrupting adjacent shared variables.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
@@ -62,6 +64,10 @@ __global__ void csa_compress_reduce_quant_kernel(
|
||||
int ps = (bi-1)*m, cs = bi*m;
|
||||
int cpt = (hd+nt-1)/nt;
|
||||
|
||||
// Shared memory: reduction scratch (8 floats = 4 slots for sum + 4 slots for max, sized for n_warps=4) + FP32 staging
|
||||
__shared__ float s_scratch[8]; // 0-3: sum/norm, 4-7: max/gsa
|
||||
__shared__ float s_vals[512]; // FP32 staging for quantize
|
||||
|
||||
float lv[4],lm[4],ld[4],la[4];
|
||||
for (int ci=0;ci<cpt;ci++) {
|
||||
int c=tid+ci*nt; if(c>=hd) break;
|
||||
@@ -83,41 +89,40 @@ __global__ void csa_compress_reduce_quant_kernel(
|
||||
lv[ci]=(ld[ci]>0)?(la[ci]/ld[ci]):0;
|
||||
}
|
||||
|
||||
// kv_norm
|
||||
if(kv_norm_weight) {
|
||||
float ls=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
|
||||
__shared__ float ss; float ts=block_reduce_sum(ls,&ss,nw);
|
||||
__shared__ float sir; if(tid==0) sir=rsqrtf(ts/hd+1e-6f); __syncthreads();
|
||||
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
|
||||
if(tid==0) s_scratch[0]=rsqrtf(ts/hd+1e-6f); // s_inv_rms
|
||||
__syncthreads();
|
||||
float sir=s_scratch[0];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
|
||||
}
|
||||
|
||||
// gsa
|
||||
float ea=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
|
||||
__shared__ float sam; float ga=block_reduce_max(ea,&sam,nw);
|
||||
float ga=block_reduce_max(ea,&s_scratch[4],nw);
|
||||
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
|
||||
if(tid==0) out_gsa[bi]=gsa;
|
||||
|
||||
__shared__ float sv[512];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;sv[c]=lv[ci];}
|
||||
// Write to shared memory for quantize
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
|
||||
__syncthreads();
|
||||
|
||||
// Quantize: each thread handles one or more 16-element blocks
|
||||
int nfb=hd/16;
|
||||
for(int b=tid;b<nfb;b+=nt) {
|
||||
int base=b*16;
|
||||
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
|
||||
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
|
||||
float bsf=ba/6.0f; bool z=(ba<6.0f*0.001953125f);
|
||||
// CRITICAL: quantize using the FP8-round-tripped block scale, not the raw value.
|
||||
// The dequant reads FP8, so quantize must match the dequant's scale exactly.
|
||||
// Same pattern as quantize_nvfp4.cu.
|
||||
// Quantize using FP8-round-tripped block scale (matches dequant)
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{
|
||||
__nv_fp8_e4m3 o(bsf);
|
||||
out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;
|
||||
bs_rt=(float)o; // round-tripped through FP8
|
||||
}
|
||||
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1; uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
@@ -134,6 +139,9 @@ __global__ void hca_compress_reduce_quant_kernel(
|
||||
if(bi>=n_blocks) return;
|
||||
int cpt=(hd+nt-1)/nt;
|
||||
|
||||
__shared__ float s_scratch[8];
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float lv[4];
|
||||
for(int ci=0;ci<cpt;ci++){
|
||||
int c=tid+ci*nt;if(c>=hd)break;
|
||||
@@ -145,32 +153,33 @@ __global__ void hca_compress_reduce_quant_kernel(
|
||||
|
||||
if(kv_norm_weight){
|
||||
float ls=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
|
||||
__shared__ float ss;float ts=block_reduce_sum(ls,&ss,nw);
|
||||
__shared__ float sir;if(tid==0)sir=rsqrtf(ts/hd+1e-6f);__syncthreads();
|
||||
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
|
||||
if(tid==0)s_scratch[0]=rsqrtf(ts/hd+1e-6f);
|
||||
__syncthreads();
|
||||
float sir=s_scratch[0];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
|
||||
}
|
||||
|
||||
float ea=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
|
||||
__shared__ float sam;float ga=block_reduce_max(ea,&sam,nw);
|
||||
float ga=block_reduce_max(ea,&s_scratch[4],nw);
|
||||
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
|
||||
if(tid==0)out_gsa[bi]=gsa;
|
||||
|
||||
__shared__ float sv[512];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;sv[c]=lv[ci];}
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
|
||||
__syncthreads();
|
||||
|
||||
int nfb=hd/16;
|
||||
for(int b=tid;b<nfb;b+=nt){
|
||||
int base=b*16;
|
||||
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
|
||||
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
|
||||
float bsf=ba/6.0f;bool z=(ba<6.0f*0.001953125f);
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1;uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user