[Perf] Warmup FlashInfer attention during startup (#23439)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
Michael Goin
2025-09-10 18:03:17 -04:00
committed by GitHub
parent b5e383cd8b
commit fba7856581
3 changed files with 55 additions and 20 deletions

View File

@@ -2578,6 +2578,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -2596,6 +2597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
"""
assert cudagraph_runtime_mode in {
@@ -2627,7 +2630,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
if uniform_decode:
if create_mixed_batch:
assert not uniform_decode
# Create mixed batch:
# first half decode tokens, second half one prefill
num_decode_tokens = num_tokens // 2
num_prefill_tokens = num_tokens - num_decode_tokens
num_reqs = num_decode_tokens + 1
# Create decode requests (1 token each) followed by prefill request
num_scheduled_tokens_list = [1] * num_decode_tokens + [
num_prefill_tokens
]
# Note: Overriding max_query_len to be the prefill tokens
max_query_len = num_prefill_tokens
elif uniform_decode:
assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch"
@@ -2652,8 +2670,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
self.seq_lens.np[:num_reqs] = self.max_model_len
if create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
# Make sure max_model_len is used at the graph capture time.
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()