[Perf] Warmup FlashInfer attention during startup (#23439)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
@@ -2578,6 +2578,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
uniform_decode: bool = False,
|
||||
skip_eplb: bool = False,
|
||||
is_profile: bool = False,
|
||||
create_mixed_batch: bool = False,
|
||||
remove_lora: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -2596,6 +2597,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
uniform_decode: If True, the batch is a uniform decode batch.
|
||||
skip_eplb: If True, skip EPLB state update.
|
||||
is_profile: If True, this is a profile run.
|
||||
create_mixed_batch: If True, create a mixed batch with both decode
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
"""
|
||||
assert cudagraph_runtime_mode in {
|
||||
@@ -2627,7 +2630,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# has num_tokens in total.
|
||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
if uniform_decode:
|
||||
if create_mixed_batch:
|
||||
assert not uniform_decode
|
||||
# Create mixed batch:
|
||||
# first half decode tokens, second half one prefill
|
||||
num_decode_tokens = num_tokens // 2
|
||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||
num_reqs = num_decode_tokens + 1
|
||||
|
||||
# Create decode requests (1 token each) followed by prefill request
|
||||
num_scheduled_tokens_list = [1] * num_decode_tokens + [
|
||||
num_prefill_tokens
|
||||
]
|
||||
# Note: Overriding max_query_len to be the prefill tokens
|
||||
max_query_len = num_prefill_tokens
|
||||
elif uniform_decode:
|
||||
assert not create_mixed_batch
|
||||
num_reqs = cdiv(num_tokens, max_query_len)
|
||||
assert num_reqs <= max_num_reqs, \
|
||||
"Do not capture num_reqs > max_num_reqs for uniform batch"
|
||||
@@ -2652,8 +2670,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
attn_metadata = {}
|
||||
|
||||
# Make sure max_model_len is used at the graph capture time.
|
||||
self.seq_lens.np[:num_reqs] = self.max_model_len
|
||||
if create_mixed_batch:
|
||||
# In the mixed batch mode (used for FI warmup), we use
|
||||
# shorter sequence lengths to run faster.
|
||||
# TODO(luka) better system for describing dummy batches
|
||||
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
|
||||
else:
|
||||
# Make sure max_model_len is used at the graph capture time.
|
||||
seq_lens = self.max_model_len
|
||||
self.seq_lens.np[:num_reqs] = seq_lens
|
||||
self.seq_lens.np[num_reqs:] = 0
|
||||
self.seq_lens.copy_to_gpu()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user