row_max is in the scale_log2 (base-2) domain, so it must be converted to the natural-log domain before it is added to ln(row_sum): attn_max = row_max * ln(2), so lse = ln(row_sum) + row_max * ln(2).
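The conversion can be checked numerically with a minimal sketch. It assumes that row_sum was accumulated as a sum of exp2(score - row_max) with both terms in base-2 units; the helper name lse_from_log2_stats and the sample scores are hypothetical and used only for illustration.

```python
import math


def lse_from_log2_stats(row_max_log2, row_sum):
    """Natural-log LSE from a base-2-domain row max and its row sum.

    Hypothetical helper: assumes row_sum = sum(2 ** (s_log2 - row_max_log2)).
    """
    attn_max = row_max_log2 * math.log(2.0)  # base-2 units -> natural-log units
    return math.log(row_sum) + attn_max


# Illustrative scores in the natural-log domain (assumed already scaled).
scores = [0.5, -1.25, 2.0, 0.0]

# Move the scores into the base-2 domain, as an exp2-based kernel would.
log2e = math.log2(math.e)
scores_log2 = [s * log2e for s in scores]

row_max = max(scores_log2)
row_sum = sum(2.0 ** (s - row_max) for s in scores_log2)

lse = lse_from_log2_stats(row_max, row_sum)
reference = math.log(sum(math.exp(s) for s in scores))  # direct log-sum-exp
```

Adding row_max to ln(row_sum) without the ln(2) factor would mix base-2 and natural-log units and give a result that differs from reference whenever row_max is nonzero.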