fix: correct swa_len_local calculation per segment for D5c multi-tile

This commit is contained in:
2026-05-26 15:22:03 +00:00
parent 3abcc7ff09
commit 2efd15c852

View File

@@ -173,40 +173,39 @@ def test_d5c_multitile():
# Run each segment
segment_results = []
n_comp_global = n_comp
swa_len_global = swa_len # number of valid SWA tokens (relative to SWA region start at n_comp)
for seg_idx in range(n_segs):
k_start = seg_idx * seg_size
k_end = min(k_start + seg_size, n_total)
k_seg = k_combined[k_start:k_end]
v_seg = v_combined[k_start:k_end]
# Determine sink bias and swa_len for this segment
# Sink bias applies to positions >= n_comp (absolute)
# In this segment, the first local position maps to absolute position k_start
# So sink bias applies to local positions >= max(0, n_comp - k_start)
n_comp_local = max(0, n_comp - k_start) # compressed tokens in this segment
# swa_len: how many SWA positions are valid in this segment
# Absolute valid range: n_comp to n_comp + swa_len - 1
swa_start_abs = n_comp
swa_end_abs = n_comp + swa_len
# This segment covers absolute positions k_start to k_end - 1
# Valid SWA in this segment: max(swa_start_abs, k_start) to min(swa_end_abs, k_end) - 1
# swa_len_local: number of valid SWA positions in this segment
valid_start = max(swa_start_abs, k_start)
valid_end = min(swa_end_abs, k_end)
if valid_end > valid_start and n_comp_local > 0:
# SWA positions in this segment: n_comp_local to n_comp_local + (valid_end - valid_start) - 1
swa_len_local = n_comp_local + (valid_end - valid_start)
elif valid_end <= valid_start:
swa_len_local = 0 # No valid SWA in this segment
else:
swa_len_local = swa_len # All SWA is valid
# If no compressed tokens and no SWA tokens in this segment, skip
if k_seg.shape[0] == 0:
continue
# For segments that are entirely compressed (no SWA), don't apply sink bias
has_swa = (k_start + seg_size) > n_comp
# n_comp_local: compressed KV tokens in this segment
# If segment starts before n_comp_global, some positions are compressed.
# If segment starts at or after n_comp_global, all positions are SWA.
if k_start < n_comp_global:
n_comp_local = min(n_comp_global - k_start, seg_size)
else:
n_comp_local = 0
# swa_len_local: the kernel masks kv_pos >= n_comp_local + swa_len_local
# We want to mask absolute positions >= n_comp_global + swa_len_global
# So: n_comp_local + swa_len_local = n_comp_global + swa_len_global - k_start
# => swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local
swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local
# Clamp to seg_size: a larger value already leaves every position in this segment unmasked
swa_len_local = min(swa_len_local, seg_size)
# If <= 0, all SWA positions are masked (no valid SWA)
if swa_len_local <= 0 and n_comp_local == 0:
continue # Skip empty segment
# Whether this segment has SWA positions (needs sink bias)
has_swa = (k_start + seg_size) > n_comp_global
o_norm, lse = run_segment(
q, k_seg, v_seg, kernel, compiled, stream,