KV-1: Fix shared memory corruption in block_reduce

block_reduce_sum/max write to smem[0..n_warps-1] but we passed &s_amax
(single float). For 128 threads / 4 warps, this wrote 4 floats starting
at &s_amax, corrupting adjacent shared variables (s_inv_rms, s_vals).

Fix: use s_scratch[8] array (4 for sum, 4 for max) with proper sizing.
This commit is contained in:
2026-06-02 09:49:12 +00:00
parent 0fefadedd4
commit 40dd56eac2

View File

@@ -1,8 +1,10 @@
/**
* FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels.
* KV-1/KV-2: Single kernel launch. FP32 -> E2M1 + E4M3 + FP32 gsa.
* Quantize: each thread independently handles 16-element blocks.
* No warp-shuffle group reduction (previous version had a cross-group bug).
*
* FIX: block_reduce_sum/max need n_warps shared memory slots, not a single
* float. Previous version passed &s_amax (1 float) but the functions write
* smem[0..n_warps-1], corrupting adjacent shared variables.
*/
#include <cuda.h>
@@ -62,6 +64,10 @@ __global__ void csa_compress_reduce_quant_kernel(
int ps = (bi-1)*m, cs = bi*m;
int cpt = (hd+nt-1)/nt;
// Shared memory: reduction scratch (8 floats for n_warps=4) + FP32 staging
__shared__ float s_scratch[8]; // 0-3: sum/norm, 4-7: max/gsa
__shared__ float s_vals[512]; // FP32 staging for quantize
float lv[4],lm[4],ld[4],la[4];
for (int ci=0;ci<cpt;ci++) {
int c=tid+ci*nt; if(c>=hd) break;
@@ -83,41 +89,40 @@ __global__ void csa_compress_reduce_quant_kernel(
lv[ci]=(ld[ci]>0)?(la[ci]/ld[ci]):0;
}
// kv_norm
if(kv_norm_weight) {
float ls=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
__shared__ float ss; float ts=block_reduce_sum(ls,&ss,nw);
__shared__ float sir; if(tid==0) sir=rsqrtf(ts/hd+1e-6f); __syncthreads();
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
if(tid==0) s_scratch[0]=rsqrtf(ts/hd+1e-6f); // s_inv_rms
__syncthreads();
float sir=s_scratch[0];
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
}
// gsa
float ea=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
__shared__ float sam; float ga=block_reduce_max(ea,&sam,nw);
float ga=block_reduce_max(ea,&s_scratch[4],nw);
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
if(tid==0) out_gsa[bi]=gsa;
__shared__ float sv[512];
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;sv[c]=lv[ci];}
// Write to shared memory for quantize
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
__syncthreads();
// Quantize: each thread handles one or more 16-element blocks
int nfb=hd/16;
for(int b=tid;b<nfb;b+=nt) {
int base=b*16;
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
float bsf=ba/6.0f; bool z=(ba<6.0f*0.001953125f);
// CRITICAL: quantize using the FP8-round-tripped block scale, not the raw value.
// The dequant reads FP8, so quantize must match the dequant's scale exactly.
// Same pattern as quantize_nvfp4.cu.
// Quantize using FP8-round-tripped block scale (matches dequant)
float bs_rt=0.0f;
if(z){out_sf[bi*(hd/16)+b]=0;}
else{
__nv_fp8_e4m3 o(bsf);
out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;
bs_rt=(float)o; // round-tripped through FP8
}
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
for(int i=0;i<8;i++){
int c0=base+2*i,c1=base+2*i+1; uint8_t lo=0,hi=0;
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
}
}
@@ -134,6 +139,9 @@ __global__ void hca_compress_reduce_quant_kernel(
if(bi>=n_blocks) return;
int cpt=(hd+nt-1)/nt;
__shared__ float s_scratch[8];
__shared__ float s_vals[512];
float lv[4];
for(int ci=0;ci<cpt;ci++){
int c=tid+ci*nt;if(c>=hd)break;
@@ -145,32 +153,33 @@ __global__ void hca_compress_reduce_quant_kernel(
if(kv_norm_weight){
float ls=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
__shared__ float ss;float ts=block_reduce_sum(ls,&ss,nw);
__shared__ float sir;if(tid==0)sir=rsqrtf(ts/hd+1e-6f);__syncthreads();
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
if(tid==0)s_scratch[0]=rsqrtf(ts/hd+1e-6f);
__syncthreads();
float sir=s_scratch[0];
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
}
float ea=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
__shared__ float sam;float ga=block_reduce_max(ea,&sam,nw);
float ga=block_reduce_max(ea,&s_scratch[4],nw);
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
if(tid==0)out_gsa[bi]=gsa;
__shared__ float sv[512];
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;sv[c]=lv[ci];}
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
__syncthreads();
int nfb=hd/16;
for(int b=tid;b<nfb;b+=nt){
int base=b*16;
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
float bsf=ba/6.0f;bool z=(ba<6.0f*0.001953125f);
float bs_rt=0.0f;
if(z){out_sf[bi*(hd/16)+b]=0;}
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
for(int i=0;i<8;i++){
int c0=base+2*i,c1=base+2*i+1;uint8_t lo=0,hi=0;
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
}
}