[Attention] Support distinguishing between short extends and decodes (#37303)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -362,6 +362,11 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
+    is_prefilling: torch.Tensor | None = None
+    """(batch_size,) bool tensor: True if request is still in prefill phase
+    (num_computed_tokens < num_prompt_tokens). Used by some backends to
+    distinguish actual decodes from short extends."""
+
     # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
     _seq_lens_cpu: torch.Tensor | None = None
     _num_computed_tokens_cpu: torch.Tensor | None = None
@@ -443,6 +448,7 @@ class CommonAttentionMetadata:
             encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
             dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
             dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+            is_prefilling=maybe_slice_reqs(self.is_prefilling),
         )
 
 
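Note: a minimal sketch of how a backend could consume the new `is_prefilling` field to tell true decodes from short extends. The helper below is hypothetical and not part of this diff; only the field's shape and meaning come from the docstring above.

```python
import torch

# Hypothetical helper (invented here). Assumes the documented (batch_size,)
# bool tensor: True while num_computed_tokens < num_prompt_tokens.
def count_true_decodes(
    query_lens: torch.Tensor, is_prefilling: torch.Tensor, decode_threshold: int = 1
) -> int:
    # A "true" decode is short AND done prefilling; a short request that is
    # still prefilling is a short extend, not a decode.
    is_true_decode = (query_lens <= decode_threshold) & ~is_prefilling
    return int(is_true_decode.sum().item())

query_lens = torch.tensor([1, 1, 1, 4])
is_prefilling = torch.tensor([False, True, False, True])
assert count_true_decodes(query_lens, is_prefilling) == 2  # requests 0 and 2
```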
@@ -358,7 +358,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=decode_threshold
+                common_attn_metadata,
+                decode_threshold=decode_threshold,
+                treat_short_extends_as_decodes=False,
             )
         )
 
@@ -489,11 +489,15 @@ def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
     require_uniform: bool = False,
+    treat_short_extends_as_decodes: bool = True,
 ) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
 
+    The batch is expected to be ordered as:
+        decode → short_extend → long_extend → prefill
+
     Args:
         common_attn_metadata: CommonAttentionMetadata object containing the
             batch metadata.
@@ -501,6 +505,9 @@ def split_decodes_and_prefills(
         require_uniform: If True, requires that all decode requests have the
             same query length. When set, some queries may be considered prefills
             even if they are <= decode_threshold, in order to ensure uniformity.
+        treat_short_extends_as_decodes: If True (default), short extends
+            (query_len <= threshold but still prefilling) are counted as
+            decodes. If False, they are counted as prefills.
 
     Returns:
         num_decodes: The number of decode requests.
@@ -513,8 +520,10 @@ def split_decodes_and_prefills(
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
 
-    if max_query_len <= decode_threshold and (
-        not require_uniform or decode_threshold <= 1
+    if (
+        max_query_len <= decode_threshold
+        and (not require_uniform or decode_threshold <= 1)
+        and treat_short_extends_as_decodes
     ):
         return num_reqs, 0, num_tokens, 0
 
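For illustration, here is the updated fast-path predicate evaluated on hand-picked values (the values are invented; the expression mirrors the condition above):

```python
# Even though every query fits under the threshold, the all-decode shortcut
# must be skipped: a query_len == 1 request could still be mid-prefill.
max_query_len, decode_threshold = 1, 1
require_uniform = False
treat_short_extends_as_decodes = False

take_fast_path = (
    max_query_len <= decode_threshold
    and (not require_uniform or decode_threshold <= 1)
    and treat_short_extends_as_decodes
)
assert take_fast_path is False
```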
@@ -533,11 +542,14 @@ def split_decodes_and_prefills(
     else:
         is_prefill = query_lens > decode_threshold
+    if not treat_short_extends_as_decodes:
+        assert common_attn_metadata.is_prefilling is not None
+        is_prefill |= common_attn_metadata.is_prefilling
 
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
 
     first_prefill = is_prefill.int().argmax(dim=-1).item()
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
     num_decode_tokens = query_start_loc[first_prefill].item()
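A self-contained sketch of the resulting boundary computation on a toy batch ordered decode → short_extend → long_extend → prefill. The tensors below stand in for the corresponding CommonAttentionMetadata fields; all numbers are invented.

```python
import torch

query_lens = torch.tensor([1, 1, 1, 6, 8])  # query length per request
is_prefilling = torch.tensor([False, True, True, True, True])
query_start_loc = torch.cat([torch.zeros(1, dtype=torch.long), query_lens.cumsum(0)])
decode_threshold = 2

# Default (treat_short_extends_as_decodes=True): anything short counts as decode.
is_prefill = query_lens > decode_threshold
assert int(is_prefill.int().argmax().item()) == 3  # reqs 0-2, incl. two short extends

# treat_short_extends_as_decodes=False: short extends join the prefill side.
is_prefill = (query_lens > decode_threshold) | is_prefilling
first_prefill = int(is_prefill.int().argmax().item())
num_decodes = first_prefill
num_decode_tokens = int(query_start_loc[first_prefill].item())
assert (num_decodes, num_decode_tokens) == (1, 1)  # only req 0 is a true decode
```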
@@ -581,39 +593,52 @@ def reorder_batch_to_split_decodes_and_prefills(
     Reorders the batch to split into prefill and decode requests; places all
     requests with <= decode_threshold tokens at the front of the batch.
 
+    The batch is reordered into 4 regions:
+        decode: (num_scheduled <= threshold AND is not prefilling)
+        short_extend: (num_scheduled <= threshold AND is chunked prefilling)
+        long_extend: (num_scheduled > threshold AND is chunked prefilling)
+        prefill: (num_computed == 0) # First chunks
+
     Returns:
         True if the batch was modified, False otherwise.
     """
-    # We now want to reorder the batch into decode → extend → prefill order
-    # where:
-    # decode: request with num_scheduled_tokens <= decode_threshold
-    # extend: non-decode request with existing context
-    # prefill: non-decode request with no existing context
-    # NOTE for now we loosely use "decode" to mean requests where attention is
-    # likely memory-bound and "prefill" to mean requests where attention is
-    # likely compute-bound,
     num_reqs = len(input_batch.req_ids)
     num_scheduled_tokens = [
         scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
     ]
     num_scheduled_tokens_np = np.array(num_scheduled_tokens)
     num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
+    num_prompt_tokens_np = input_batch.num_prompt_tokens[:num_reqs]
 
-    is_prefill = num_computed_tokens_np == 0
-    is_decode = (num_scheduled_tokens_np <= decode_threshold) & (~is_prefill)
-    is_extend = (num_scheduled_tokens_np > decode_threshold) & (~is_prefill)
+    has_context = num_computed_tokens_np > 0
+    is_below_threshold = num_scheduled_tokens_np <= decode_threshold
+    done_prefilling = num_computed_tokens_np >= num_prompt_tokens_np
 
-    # Desired order: decode → extend → prefill
-    req_regions = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
-    req_regions[is_extend] = 1
-    req_regions[is_prefill] = 2
+    # Mutually exclusive categories (exactly one True per request):
+    # 1. No context yet -> prefill
+    # 2. Has context, above threshold -> long_extend
+    # 3. Has context, below threshold, still prefilling -> short_extend
+    # 4. Has context, below threshold, done prefilling -> decode
+    is_pure_prefill = ~has_context
+    is_long_extend = has_context & ~is_below_threshold
+    is_short_extend = has_context & is_below_threshold & ~done_prefilling
+    is_decode = has_context & is_below_threshold & done_prefilling
+
+    # Desired order: decode → short_extend → long_extend → prefill
+    req_regions = np.zeros(num_reqs, dtype=np.int32)  # 0 = decode by default
+    req_regions[is_short_extend] = 1
+    req_regions[is_long_extend] = 2
+    req_regions[is_pure_prefill] = 3
 
     num_decodes = int(is_decode.sum())
-    num_extends = int(is_extend.sum())
+    num_short_extends = int(is_short_extend.sum())
+    num_long_extends = int(is_long_extend.sum())
+    num_prefills = int(is_pure_prefill.sum())
 
-    target_regions = np.zeros(num_reqs, dtype=np.int32)
-    target_regions[num_decodes : num_decodes + num_extends] = 1
-    target_regions[num_decodes + num_extends :] = 2
+    target_regions = np.repeat(
+        [0, 1, 2, 3],
+        [num_decodes, num_short_extends, num_long_extends, num_prefills],
+    ).astype(np.int32)
 
     needs_swap = req_regions != target_regions
 
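Finally, a standalone NumPy sketch of the four-region classification and the np.repeat-based target layout (toy numbers invented here; only the array logic mirrors the code above):

```python
import numpy as np

num_scheduled = np.array([7, 1, 1, 2, 5])  # tokens scheduled this step
num_computed = np.array([4, 9, 3, 6, 0])   # tokens already computed (context)
num_prompt = np.array([20, 5, 8, 6, 10])   # prompt length per request
decode_threshold = 2

has_context = num_computed > 0
is_below_threshold = num_scheduled <= decode_threshold
done_prefilling = num_computed >= num_prompt

is_pure_prefill = ~has_context                                         # req 4
is_long_extend = has_context & ~is_below_threshold                     # req 0
is_short_extend = has_context & is_below_threshold & ~done_prefilling  # req 2
is_decode = has_context & is_below_threshold & done_prefilling         # reqs 1, 3

req_regions = np.zeros(len(num_scheduled), dtype=np.int32)
req_regions[is_short_extend] = 1
req_regions[is_long_extend] = 2
req_regions[is_pure_prefill] = 3
print(req_regions)  # [2 0 1 0 3]

counts = [int(is_decode.sum()), int(is_short_extend.sum()),
          int(is_long_extend.sum()), int(is_pure_prefill.sum())]
target_regions = np.repeat([0, 1, 2, 3], counts).astype(np.int32)
print(target_regions)                 # [0 0 1 2 3]
print(req_regions != target_regions)  # [ True False False  True False]: swap reqs 0 and 3
```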