diff --git a/dsv4/kernels/cuda/compressor_reduce_quant.cu b/dsv4/kernels/cuda/compressor_reduce_quant.cu index 797471d7..08fcc80a 100644 --- a/dsv4/kernels/cuda/compressor_reduce_quant.cu +++ b/dsv4/kernels/cuda/compressor_reduce_quant.cu @@ -1,8 +1,10 @@ /** * FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels. * KV-1/KV-2: Single kernel launch. FP32 -> E2M1 + E4M3 + FP32 gsa. - * Quantize: each thread independently handles 16-element blocks. - * No warp-shuffle group reduction (previous version had a cross-group bug). + * + * FIX: block_reduce_sum/max need n_warps shared memory slots, not a single + * float. Previous version passed &s_amax (1 float) but the functions write + * smem[0..n_warps-1], corrupting adjacent shared variables. */ #include @@ -62,6 +64,10 @@ __global__ void csa_compress_reduce_quant_kernel( int ps = (bi-1)*m, cs = bi*m; int cpt = (hd+nt-1)/nt; + // Shared memory: reduction scratch (8 floats for n_warps=4) + FP32 staging + __shared__ float s_scratch[8]; // 0-3: sum/norm, 4-7: max/gsa + __shared__ float s_vals[512]; // FP32 staging for quantize + float lv[4],lm[4],ld[4],la[4]; for (int ci=0;ci=hd) break; @@ -83,41 +89,40 @@ __global__ void csa_compress_reduce_quant_kernel( lv[ci]=(ld[ci]>0)?(la[ci]/ld[ci]):0; } + // kv_norm if(kv_norm_weight) { float ls=0; for(int ci=0;ci=hd)break;ls+=lv[ci]*lv[ci];} - __shared__ float ss; float ts=block_reduce_sum(ls,&ss,nw); - __shared__ float sir; if(tid==0) sir=rsqrtf(ts/hd+1e-6f); __syncthreads(); + float ts=block_reduce_sum(ls,&s_scratch[0],nw); + if(tid==0) s_scratch[0]=rsqrtf(ts/hd+1e-6f); // s_inv_rms + __syncthreads(); + float sir=s_scratch[0]; for(int ci=0;ci=hd)break;lv[ci]*=sir*kv_norm_weight[c];} } + // gsa float ea=0; for(int ci=0;ci=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));} - __shared__ float sam; float ga=block_reduce_max(ea,&sam,nw); + float ga=block_reduce_max(ea,&s_scratch[4],nw); float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f); if(tid==0) out_gsa[bi]=gsa; - __shared__ float sv[512]; - for(int ci=0;ci=hd)break;sv[c]=lv[ci];} + // Write to shared memory for quantize + for(int ci=0;ci=hd)break;s_vals[c]=lv[ci];} __syncthreads(); + // Quantize: each thread handles one or more 16-element blocks int nfb=hd/16; for(int b=tid;b12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} - if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} + if(!z&&c012)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} + if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo; } } @@ -134,6 +139,9 @@ __global__ void hca_compress_reduce_quant_kernel( if(bi>=n_blocks) return; int cpt=(hd+nt-1)/nt; + __shared__ float s_scratch[8]; + __shared__ float s_vals[512]; + float lv[4]; for(int ci=0;ci=hd)break; @@ -145,32 +153,33 @@ __global__ void hca_compress_reduce_quant_kernel( if(kv_norm_weight){ float ls=0;for(int ci=0;ci=hd)break;ls+=lv[ci]*lv[ci];} - __shared__ float ss;float ts=block_reduce_sum(ls,&ss,nw); - __shared__ float sir;if(tid==0)sir=rsqrtf(ts/hd+1e-6f);__syncthreads(); + float ts=block_reduce_sum(ls,&s_scratch[0],nw); + if(tid==0)s_scratch[0]=rsqrtf(ts/hd+1e-6f); + __syncthreads(); + float sir=s_scratch[0]; for(int ci=0;ci=hd)break;lv[ci]*=sir*kv_norm_weight[c];} } float ea=0;for(int ci=0;ci=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));} - __shared__ float sam;float ga=block_reduce_max(ea,&sam,nw); + float ga=block_reduce_max(ea,&s_scratch[4],nw); float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f); if(tid==0)out_gsa[bi]=gsa; - __shared__ float sv[512]; - for(int ci=0;ci=hd)break;sv[c]=lv[ci];} + for(int ci=0;ci=hd)break;s_vals[c]=lv[ci];} __syncthreads(); int nfb=hd/16; for(int b=tid;b12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} - if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} + if(!z&&c012)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} + if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo; } }