[Attention] Support distinguishing between short extends and decodes (#37303)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-03-20 10:49:36 -07:00
committed by GitHub
parent 79eb9369c5
commit e1d85e5c24
9 changed files with 176 additions and 133 deletions

View File

@@ -362,6 +362,11 @@ class CommonAttentionMetadata:
dcp_local_seq_lens_cpu: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""
+is_prefilling: torch.Tensor | None = None
+"""(batch_size,) bool tensor: True if request is still in prefill phase
+(num_computed_tokens < num_prompt_tokens). Used by some backends to
+distinguish actual decodes from short extends."""
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
_seq_lens_cpu: torch.Tensor | None = None
_num_computed_tokens_cpu: torch.Tensor | None = None
@@ -443,6 +448,7 @@ class CommonAttentionMetadata:
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+is_prefilling=maybe_slice_reqs(self.is_prefilling),
)

View File

@@ -358,7 +358,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
-common_attn_metadata, decode_threshold=decode_threshold
+common_attn_metadata,
+decode_threshold=decode_threshold,
+treat_short_extends_as_decodes=False,
)
)

View File

@@ -489,11 +489,15 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False,
+treat_short_extends_as_decodes: bool = True,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
+The batch is expected to be ordered as:
+decode → short_extend → long_extend → prefill
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
@@ -501,6 +505,9 @@ def split_decodes_and_prefills(
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.
+treat_short_extends_as_decodes: If True (default), short extends
+(query_len <= threshold but still prefilling) are counted as
+decodes. If False, they are counted as prefills.
Returns:
num_decodes: The number of decode requests.
@@ -513,8 +520,10 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
-if max_query_len <= decode_threshold and (
-not require_uniform or decode_threshold <= 1
+if (
+max_query_len <= decode_threshold
+and (not require_uniform or decode_threshold <= 1)
+and treat_short_extends_as_decodes
):
return num_reqs, 0, num_tokens, 0
@@ -533,11 +542,14 @@ def split_decodes_and_prefills(
else:
is_prefill = query_lens > decode_threshold
+if not treat_short_extends_as_decodes:
+assert common_attn_metadata.is_prefilling is not None
+is_prefill |= common_attn_metadata.is_prefilling
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()
@@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills(
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
+The batch is reordered into 4 regions:
+decode: (num_scheduled <= threshold AND is not prefilling)
+short_extend: (num_scheduled <= threshold AND is chunked prefilling)
+long_extend: (num_scheduled > threshold AND is chunked prefilling)
+prefill: (num_computed == 0) # First chunks
Returns:
True if the batch was modified, False otherwise.
"""
-# We now want to reorder the batch into decode → extend → prefill order
-# where:
-# decode: request with num_scheduled_tokens <= decode_threshold
-# extend: non-decode request with existing context
-# prefill: non-decode request with no existing context
# NOTE for now we loosely use "decode" to mean requests where attention is
# likely memory-bound and "prefill" to mean requests where attention is
# likely compute-bound,
num_reqs = len(input_batch.req_ids)
num_scheduled_tokens = [
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
]
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
num_prompt_tokens_np = input_batch.num_prompt_tokens[:num_reqs]
-is_prefill = num_computed_tokens_np == 0
-is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
-is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)
+has_context = num_computed_tokens_np > 0
+is_below_threshold = num_scheduled_tokens_np <= decode_threshold
+done_prefilling = num_computed_tokens_np >= num_prompt_tokens_np
-# Desired order: decode → extend → prefill
-req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
-req_regions[is_extend] = 1
-req_regions[is_prefill] = 2
+# Mutually exclusive categories (exactly one True per request):
+# 1. No context yet -> prefill
+# 2. Has context, above threshold -> long_extend
+# 3. Has context, below threshold, still prefilling -> short_extend
+# 4. Has context, below threshold, done prefilling -> decode
+is_pure_prefill = ~has_context
+is_long_extend = has_context & ~is_below_threshold
+is_short_extend = has_context & is_below_threshold & ~done_prefilling
+is_decode = has_context & is_below_threshold & done_prefilling
+# Desired order: decode → short_extend → long_extend → prefill
+req_regions = np.zeros(num_reqs, dtype=np.int32) # 0 = decode by default
+req_regions[is_short_extend] = 1
+req_regions[is_long_extend] = 2
+req_regions[is_pure_prefill] = 3
num_decodes = int(is_decode.sum())
-num_extends = int(is_extend.sum())
+num_short_extends = int(is_short_extend.sum())
+num_long_extends = int(is_long_extend.sum())
+num_prefills = int(is_pure_prefill.sum())
-target_regions = np.zeros(num_reqs, dtype=np.int32)
-target_regions[num_decodes : num_decodes + num_extends] = 1
-target_regions[num_decodes + num_extends :] = 2
+target_regions = np.repeat(
+[0, 1, 2, 3],
+[num_decodes, num_short_extends, num_long_extends, num_prefills],
+).astype(np.int32)
needs_swap = req_regions != target_regions