diff --git a/tests/unit/test_d5c_multitile.py b/tests/unit/test_d5c_multitile.py index b84659ca..f6aad639 100644 --- a/tests/unit/test_d5c_multitile.py +++ b/tests/unit/test_d5c_multitile.py @@ -173,40 +173,39 @@ def test_d5c_multitile(): # Run each segment segment_results = [] + n_comp_global = n_comp + swa_len_global = swa_len # number of valid SWA tokens (relative to SWA region start at n_comp) + for seg_idx in range(n_segs): k_start = seg_idx * seg_size k_end = min(k_start + seg_size, n_total) k_seg = k_combined[k_start:k_end] v_seg = v_combined[k_start:k_end] - # Determine sink bias and swa_len for this segment - # Sink bias applies to positions >= n_comp (absolute) - # In this segment, the first local position maps to absolute position k_start - # So sink bias applies to local positions >= max(0, n_comp - k_start) - n_comp_local = max(0, n_comp - k_start) # compressed tokens in this segment - # swa_len: how many SWA positions are valid in this segment - # Absolute valid range: n_comp to n_comp + swa_len - 1 - swa_start_abs = n_comp - swa_end_abs = n_comp + swa_len - # This segment covers absolute positions k_start to k_end - 1 - # Valid SWA in this segment: max(swa_start_abs, k_start) to min(swa_end_abs, k_end) - 1 - # swa_len_local: number of valid SWA positions in this segment - valid_start = max(swa_start_abs, k_start) - valid_end = min(swa_end_abs, k_end) - if valid_end > valid_start and n_comp_local > 0: - # SWA positions in this segment: n_comp_local to n_comp_local + (valid_end - valid_start) - 1 - swa_len_local = n_comp_local + (valid_end - valid_start) - elif valid_end <= valid_start: - swa_len_local = 0 # No valid SWA in this segment - else: - swa_len_local = swa_len # All SWA is valid - - # If no compressed tokens and no SWA tokens in this segment, skip if k_seg.shape[0] == 0: continue - # For segments that are entirely compressed (no SWA), don't apply sink bias - has_swa = (k_start + seg_size) > n_comp + # n_comp_local: compressed KV tokens in this segment + # If segment starts before n_comp_global, some positions are compressed. + # If segment starts at or after n_comp_global, all positions are SWA. + if k_start < n_comp_global: + n_comp_local = min(n_comp_global - k_start, seg_size) + else: + n_comp_local = 0 + + # swa_len_local: the kernel masks kv_pos >= n_comp_local + swa_len_local + # We want to mask absolute positions >= n_comp_global + swa_len_global + # So: n_comp_local + swa_len_local = n_comp_global + swa_len_global - k_start + # => swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local + swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local + # Clamp: if > remaining segment size, no masking needed + swa_len_local = min(swa_len_local, seg_size) + # If <= 0, all SWA positions are masked (no valid SWA) + if swa_len_local <= 0 and n_comp_local == 0: + continue # Skip empty segment + + # Whether this segment has SWA positions (needs sink bias) + has_swa = (k_start + seg_size) > n_comp_global o_norm, lse = run_segment( q, k_seg, v_seg, kernel, compiled, stream,