[Core] Support full cuda graph in v1 (#16072)
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
@@ -12,6 +12,7 @@ import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")
 
+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
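For context, a minimal sketch of turning the new flag on from user code. Only the `full_cuda_graph` field is taken from this diff; passing `compilation_config` as a dict to `LLM`, and the model name, are assumptions about the installed vLLM build, not confirmed here.

    # Hypothetical usage sketch: enable full CUDA graph capture.
    # `full_cuda_graph` is the flag gated above; the dict-valued
    # `compilation_config` argument and the model are illustrative.
    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # any FA3-capable setup
        compilation_config={"full_cuda_graph": True},
    )
    # The check above raises ValueError unless the FlashAttention 3 backend
    # (FlashAttentionBackend with flash_attn_version == 3) is in use.
    print(llm.generate("Hello")[0].outputs[0].text)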
@@ -219,6 +230,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -271,7 +292,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                             pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@@ -589,10 +610,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]
 
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
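The in-place copy_()/fill_() writes above target persistent device buffers because a captured CUDA graph replays against fixed memory addresses. A self-contained sketch of that pattern with plain PyTorch follows; the names, shapes, and the toy step() function are illustrative, not vLLM internals.

    # Standalone illustration of the persistent-buffer pattern behind full
    # CUDA graph capture: the graph records fixed tensor addresses, so new
    # inputs must be copied into the same buffers instead of reallocated.
    import torch

    MAX_TOKENS = 8
    static_in = torch.zeros(MAX_TOKENS, device="cuda")   # persistent input
    static_out = torch.zeros(MAX_TOKENS, device="cuda")  # persistent output

    def step(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0 + 1.0

    # Warm up on a side stream (as the PyTorch docs recommend), then capture
    # one replayable iteration into the static output buffer.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        step(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(step(static_in))

    # Replay with new data: copy into the fixed buffer and pad the unused
    # tail, analogous to the fill_(-1)/fill_(0) calls in the diff above.
    tokens = torch.arange(5, device="cuda", dtype=torch.float32)
    static_in[:5].copy_(tokens)
    static_in[5:].fill_(0)
    g.replay()
    print(static_out[:5])  # tensor([1., 3., 5., 7., 9.], device='cuda:0')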
@@ -1478,6 +1511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
@@ -1522,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         for k, v in self.intermediate_tensors.items()
                     })
 
-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
@@ -1708,11 +1759,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
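The capture loop above walks the batch sizes from largest to smallest so later captures can reuse the memory pool allocated for the first one. A rough standalone sketch of that ordering with stock torch.cuda APIs; the toy model, batch sizes, and warmup count of two are made up for illustration and are not vLLM's runner code.

    # Illustrative sketch of per-batch-size graph capture, largest first,
    # sharing one memory pool across captures.
    import torch

    batch_sizes = [1, 2, 4, 8]
    pool = torch.cuda.graph_pool_handle()        # shared allocation pool
    weight = torch.randn(16, 16, device="cuda")
    static_in = {n: torch.zeros(n, 16, device="cuda") for n in batch_sizes}
    static_out = {}
    graphs = {}

    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x @ weight)

    side = torch.cuda.Stream()
    for n in reversed(batch_sizes):              # large shapes captured first
        # Warm up a couple of times, mirroring cudagraph_num_of_warmups.
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(2):
                forward(static_in[n])
        torch.cuda.current_stream().wait_stream(side)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=pool):     # reuse the shared pool
            static_out[n] = forward(static_in[n])
        graphs[n] = g

    # Replay the batch-size-4 graph after refreshing its input buffer.
    static_in[4].copy_(torch.randn(4, 16, device="cuda"))
    graphs[4].replay()
    print(static_out[4].shape)                   # torch.Size([4, 16])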