fix: correct swa_len_local calculation per segment for D5c multi-tile
This commit is contained in:
@@ -173,40 +173,39 @@ def test_d5c_multitile():
|
||||
|
||||
# Run each segment
|
||||
segment_results = []
|
||||
n_comp_global = n_comp
|
||||
swa_len_global = swa_len # number of valid SWA tokens (relative to SWA region start at n_comp)
|
||||
|
||||
for seg_idx in range(n_segs):
|
||||
k_start = seg_idx * seg_size
|
||||
k_end = min(k_start + seg_size, n_total)
|
||||
k_seg = k_combined[k_start:k_end]
|
||||
v_seg = v_combined[k_start:k_end]
|
||||
|
||||
# Determine sink bias and swa_len for this segment
|
||||
# Sink bias applies to positions >= n_comp (absolute)
|
||||
# In this segment, the first local position maps to absolute position k_start
|
||||
# So sink bias applies to local positions >= max(0, n_comp - k_start)
|
||||
n_comp_local = max(0, n_comp - k_start) # compressed tokens in this segment
|
||||
# swa_len: how many SWA positions are valid in this segment
|
||||
# Absolute valid range: n_comp to n_comp + swa_len - 1
|
||||
swa_start_abs = n_comp
|
||||
swa_end_abs = n_comp + swa_len
|
||||
# This segment covers absolute positions k_start to k_end - 1
|
||||
# Valid SWA in this segment: max(swa_start_abs, k_start) to min(swa_end_abs, k_end) - 1
|
||||
# swa_len_local: number of valid SWA positions in this segment
|
||||
valid_start = max(swa_start_abs, k_start)
|
||||
valid_end = min(swa_end_abs, k_end)
|
||||
if valid_end > valid_start and n_comp_local > 0:
|
||||
# SWA positions in this segment: n_comp_local to n_comp_local + (valid_end - valid_start) - 1
|
||||
swa_len_local = n_comp_local + (valid_end - valid_start)
|
||||
elif valid_end <= valid_start:
|
||||
swa_len_local = 0 # No valid SWA in this segment
|
||||
else:
|
||||
swa_len_local = swa_len # All SWA is valid
|
||||
|
||||
# If no compressed tokens and no SWA tokens in this segment, skip
|
||||
if k_seg.shape[0] == 0:
|
||||
continue
|
||||
|
||||
# For segments that are entirely compressed (no SWA), don't apply sink bias
|
||||
has_swa = (k_start + seg_size) > n_comp
|
||||
# n_comp_local: compressed KV tokens in this segment
|
||||
# If segment starts before n_comp_global, some positions are compressed.
|
||||
# If segment starts at or after n_comp_global, all positions are SWA.
|
||||
if k_start < n_comp_global:
|
||||
n_comp_local = min(n_comp_global - k_start, seg_size)
|
||||
else:
|
||||
n_comp_local = 0
|
||||
|
||||
# swa_len_local: the kernel masks kv_pos >= n_comp_local + swa_len_local
|
||||
# We want to mask absolute positions >= n_comp_global + swa_len_global
|
||||
# So: n_comp_local + swa_len_local = n_comp_global + swa_len_global - k_start
|
||||
# => swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local
|
||||
swa_len_local = n_comp_global + swa_len_global - k_start - n_comp_local
|
||||
# Clamp to seg_size: a larger value means no position in this segment is masked
|
||||
swa_len_local = min(swa_len_local, seg_size)
|
||||
# If swa_len_local <= 0, no SWA position in this segment is valid; skip the segment only when it also holds no compressed tokens
|
||||
if swa_len_local <= 0 and n_comp_local == 0:
|
||||
continue # Skip empty segment
|
||||
|
||||
# Whether this segment has SWA positions (needs sink bias)
|
||||
has_swa = (k_start + seg_size) > n_comp_global
|
||||
|
||||
o_norm, lse = run_segment(
|
||||
q, k_seg, v_seg, kernel, compiled, stream,
|
||||
|
||||
Reference in New Issue
Block a user