[Core] Support full cuda graph in v1 (#16072)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Chanh Nguyen
2025-05-07 22:30:15 -07:00
committed by GitHub
parent 3d13ca0e24
commit 7ea2adb802
5 changed files with 190 additions and 13 deletions
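
Summary: with `compilation_config.full_cuda_graph` enabled, the v1 GPU model runner captures the entire forward pass, attention kernels included, into the CUDA graph, instead of running attention eagerly between piecewise-captured segments. This requires FlashAttention 3 and persistent device buffers for the attention metadata. A minimal usage sketch (the model name is a placeholder; assumes a build containing this change and an FA3-capable GPU):

```python
from vllm import LLM

# Opt in to full CUDA graph capture (attention included). The flag name
# comes from this diff; the dict form of compilation_config is used here.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    compilation_config={"full_cuda_graph": True},
)
print(llm.generate("Hello, world!")[0].outputs[0].text)
```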


@@ -12,6 +12,7 @@ import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")
 
+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
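
The guard above fails fast at initialization rather than at capture time: full-graph capture is only admitted when the FlashAttention backend is active at version 3. For illustration, the same check restated as a standalone function (not part of the diff):

```python
from typing import Optional

def check_full_cuda_graph_support(attn_backend_name: str,
                                  flash_attn_version: Optional[int]) -> None:
    # Mirrors the guard above: reject anything other than the
    # FlashAttention backend at version 3 before any graph is captured.
    if attn_backend_name != "FlashAttentionBackend" or \
            flash_attn_version != 3:
        raise ValueError(
            f"full_cuda_graph is only supported with FA3. Current attention "
            f"backend is {attn_backend_name}, FlashAttention version is "
            f"{flash_attn_version}.")

check_full_cuda_graph_support("FlashAttentionBackend", 3)  # passes silently
# check_full_cuda_graph_support("TritonAttentionBackend", None)  # raises
```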
@@ -219,6 +230,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
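
These device-resident twins of the existing CPU staging buffers are the heart of the change: a captured CUDA graph replays fixed memory addresses, so the attention metadata must live at the same device location on every step. A generic sketch of the pattern in plain PyTorch (sizes illustrative):

```python
import torch

# Sketch of the persistent-buffer pattern above: allocate once at the
# maximum size, then write each step's values into a prefix slice, so a
# captured CUDA graph can keep reading from the same device addresses.
MAX_NUM_TOKENS = 8192  # illustrative capacity
device = "cuda" if torch.cuda.is_available() else "cpu"
slot_mapping = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64, device=device)

def write_step(step_slots: torch.Tensor) -> None:
    n = step_slots.numel()
    slot_mapping[:n].copy_(step_slots, non_blocking=True)  # reuse, no realloc
    slot_mapping[n:].fill_(-1)  # pad the tail; see the fill logic below
```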
@@ -271,7 +292,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                          pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
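
The pinned staging buffer moves from int32 to int64 so it matches the persistent device-side `slot_mapping` added above, making the non-blocking `copy_` in the next hunk a straight same-dtype transfer. int64 also leaves headroom: a flat slot index (`block_number * block_size + offset`) could overflow int32 on a very large paged KV cache. A quick illustration of that headroom argument (hypothetical numbers):

```python
import torch

# Hypothetical cache size, for illustration only: a flat slot index
# addressing a huge paged KV cache can exceed what int32 represents.
block_size = 16
num_blocks = 150_000_000
max_slot = num_blocks * block_size - 1          # 2.4e9: largest slot id
print(max_slot > torch.iinfo(torch.int32).max)  # True -> int32 overflows
```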
@@ -589,10 +610,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]
 
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
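
The fills matter during graph replay: a captured graph always processes the buffer's full extent, so positions beyond the tokens actually scheduled must be rendered harmless. Giving them slot -1 lets `reshape_and_cache` skip them instead of scribbling over live KV blocks. A pure-Python stand-in for that skip behavior (not the real kernel):

```python
import torch

# Stand-in for why padded slots are -1: a kernel like reshape_and_cache
# writes each token's KV entry at its slot index and skips negative slots,
# so padded positions that a replayed graph still "processes" never touch
# live cache blocks.
def reshape_and_cache_sim(values: torch.Tensor, slot_mapping: torch.Tensor,
                          cache: torch.Tensor) -> None:
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot >= 0:  # -1 marks a padded position
            cache[slot] = values[i]

cache = torch.zeros(16, 4)
values = torch.randn(8, 4)
slots = torch.tensor([3, 5, 7, -1, -1, -1, -1, -1], dtype=torch.int64)
reshape_and_cache_sim(values, slots, cache)  # only slots 3, 5, 7 written
```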
@@ -1478,6 +1511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
+
+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
@@ -1522,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     for k, v in self.intermediate_tensors.items()
                 })
-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
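
In the two hunks above, `skip_attn` defaults to True, preserving the old behavior (dummy runs with no attention metadata, as used for piecewise capture and profiling). In full-graph mode, the builder is fed worst-case shapes whose tensors are slices of the persistent buffers, and the resulting metadata now reaches `set_forward_context`, so whatever the capture records is exactly what later replays will execute. A toy restatement of the switch (helper names are illustrative, not vLLM API):

```python
from typing import Optional

def build_dummy_attn_metadata(num_tokens: int) -> dict:
    # Stand-in for attn_metadata_builder.build(...): the real call slices
    # the persistent query_start_loc / seq_lens buffers shown earlier.
    return {"num_actual_tokens": num_tokens, "max_query_len": num_tokens}

def dummy_run(num_tokens: int, skip_attn: bool = True) -> Optional[dict]:
    # None -> attention runs outside the captured region (old behavior);
    # real metadata -> the attention kernels are captured into the graph.
    return None if skip_attn else build_dummy_attn_metadata(num_tokens)

assert dummy_run(256) is None
assert dummy_run(256, skip_attn=False)["max_query_len"] == 256
```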
@@ -1708,11 +1759,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
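
The only change to the capture loop is threading `skip_attn` through, but the surrounding pattern is worth spelling out: warm up each batch size, then capture it, going from the largest shape to the smallest so later graphs can reuse the first graph's memory pool. A standalone PyTorch sketch of that pattern (plain `torch.cuda.graph`, not vLLM's `graph_capture` helper):

```python
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T)

batch_sizes = [512, 256, 128]  # largest first, mirroring reversed(...) above
graphs, static_inputs, pool = {}, {}, None

for n in batch_sizes:
    x = torch.randn(n, n, device="cuda")
    static_inputs[n] = x

    # Warm up on a side stream (mirrors cudagraph_num_of_warmups) so lazy
    # initialization does not end up inside the captured graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            forward(x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):  # pool=None for the first capture
        out = forward(x)
    pool = g.pool()  # later (smaller) graphs reuse this memory pool
    graphs[n] = (g, out)

# Replay: refresh the static input in place, then launch the whole graph.
static_inputs[256].copy_(torch.randn(256, 256, device="cuda"))
graphs[256][0].replay()
```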