KV-1: Fix FP8 round-trip mismatch in fused quantize

CRITICAL: quantize must use the FP8-round-tripped block scale, not the raw
pre-FP8 value. The dequant reads the FP8 bytes back, so the quantize must
match exactly. Same pattern as quantize_nvfp4.cu. This was the root cause
of cos=0.925 (should be ~0.995).
This commit is contained in:
2026-06-02 09:46:32 +00:00
parent d74ff5768d
commit 0fefadedd4

View File

@@ -104,12 +104,20 @@ __global__ void csa_compress_reduce_quant_kernel(
int base=b*16;
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
float bsf=ba/6.0f; bool z=(ba<6.0f*0.001953125f);
if(z){out_sf[bi*(hd/16)+b]=0;}else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;}
float ab=z?0.0f:bsf;
// CRITICAL: quantize using the FP8-round-tripped block scale, not the raw value.
// The dequant reads FP8, so quantize must match the dequant's scale exactly.
// Same pattern as quantize_nvfp4.cu.
float bs_rt=0.0f;
if(z){out_sf[bi*(hd/16)+b]=0;}
else{
__nv_fp8_e4m3 o(bsf);
out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;
bs_rt=(float)o; // round-tripped through FP8
}
for(int i=0;i<8;i++){
int c0=base+2*i,c1=base+2*i+1; uint8_t lo=0,hi=0;
if(!z&&c0<hd){float s=(sv[c0]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
}
}
@@ -156,12 +164,13 @@ __global__ void hca_compress_reduce_quant_kernel(
int base=b*16;
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(sv[c])/gsa);}
float bsf=ba/6.0f;bool z=(ba<6.0f*0.001953125f);
if(z){out_sf[bi*(hd/16)+b]=0;}else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;}
float ab=z?0.0f:bsf;
float bs_rt=0.0f;
if(z){out_sf[bi*(hd/16)+b]=0;}
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
for(int i=0;i<8;i++){
int c0=base+2*i,c1=base+2*i+1;uint8_t lo=0,hi=0;
if(!z&&c0<hd){float s=(sv[c0]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/ab;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
if(!z&&c0<hd){float s=(sv[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
if(!z&&c1<hd){float s=(sv[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
}
}