KV-1: Fix FP8 round-trip mismatch in fused quantize
CRITICAL: quantize must use the FP8-round-tripped block scale, not the raw pre-FP8 value. The dequant reads the FP8 bytes back, so the quantize must match exactly. Same pattern as quantize_nvfp4.cu. This was the root cause of cos=0.925 (should be ~0.995).
This commit is contained in:
@@ -104,12 +104,20 @@ __global__ void csa_compress_reduce_quant_kernel(
|
||||
int base=b*16;
|
||||
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
|
||||
float bsf=ba/6.0f; bool z=(ba<6.0f*0.001953125f);
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;}
|
||||
float ab=z?0.0f:bsf;
|
||||
// CRITICAL: quantize using the FP8-round-tripped block scale, not the raw value.
|
||||
// The dequant reads FP8, so quantize must match the dequant's scale exactly.
|
||||
// Same pattern as quantize_nvfp4.cu.
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{
|
||||
__nv_fp8_e4m3 o(bsf);
|
||||
out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;
|
||||
bs_rt=(float)o; // round-tripped through FP8
|
||||
}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1; uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
@@ -156,12 +164,13 @@ __global__ void hca_compress_reduce_quant_kernel(
|
||||
int base=b*16;
|
||||
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
|
||||
float bsf=ba/6.0f;bool z=(ba<6.0f*0.001953125f);
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;}
|
||||
float ab=z?0.0f:bsf;
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1;uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user