[V1] Integrate Piecewise CUDA graphs (#10058)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-11-05 22:16:04 -08:00
committed by GitHub
parent 9d59b75593
commit 4089985552
3 changed files with 133 additions and 36 deletions

View File

@@ -496,7 +496,10 @@ class PiecewiseBackend:
return entry.runnable(*args) return entry.runnable(*args)
if self.is_first_graph: if self.is_first_graph:
logger.info("Capturing a cudagraph for shape %s", # Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape) runtime_shape)
input_addresses = [ input_addresses = [

View File

@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
# |-------------------- seq_len ---------------------| # |-------------------- seq_len ---------------------|
# |-- query_len ---| # |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int max_query_len: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
max_seq_len: int max_seq_len: int
@@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
assert k_scale == 1.0 and v_scale == 1.0, ( assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.") "key/v_scale is not supported in FlashAttention.")
output = torch.ops.vllm.unified_flash_attention( output = torch.empty_like(query)
torch.ops.vllm.unified_flash_attention(
output,
query, query,
key, key,
value, value,
@@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
def unified_flash_attention( def unified_flash_attention(
output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
@@ -168,17 +172,17 @@ def unified_flash_attention(
window_size: Optional[List[int]] = None, window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
) -> torch.Tensor: ) -> None:
current_metadata = get_forward_context() current_metadata = get_forward_context()
if current_metadata is None: if current_metadata is None:
# Profiling run. # Profiling run.
return torch.empty_like(query) return
assert current_metadata is not None assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata) assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size) query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size) key = key.view(-1, num_kv_heads, head_size)
@@ -188,18 +192,18 @@ def unified_flash_attention(
key_cache = kv_cache[0] key_cache = kv_cache[0]
value_cache = kv_cache[1] value_cache = kv_cache[1]
torch.ops._C_cache_ops.reshape_and_cache_flash( torch.ops._C_cache_ops.reshape_and_cache_flash(
key, key[:num_actual_tokens],
value, value[:num_actual_tokens],
kv_cache[0], key_cache,
kv_cache[1], value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
kv_cache_dtype, kv_cache_dtype,
k_scale, k_scale,
v_scale, v_scale,
) )
output = flash_attn_varlen_func( attn_output = flash_attn_varlen_func(
q=query, q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc, cu_seqlens_q=attn_metadata.query_start_loc,
@@ -213,10 +217,13 @@ def unified_flash_attention(
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
softcap=logits_soft_cap, softcap=logits_soft_cap,
) )
return output.view(num_tokens, hidden_size) attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
output[:num_actual_tokens].copy_(attn_output)
def unified_flash_attention_fake( def unified_flash_attention_fake(
output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
window_size: Optional[List[int]] = None, window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
) -> torch.Tensor: ) -> None:
return torch.empty_like(query) return
direct_register_custom_op( direct_register_custom_op(
op_name="unified_flash_attention", op_name="unified_flash_attention",
op_func=unified_flash_attention, op_func=unified_flash_attention,
mutates_args=["kv_cache"], mutates_args=["kv_cache", "output"],
fake_impl=unified_flash_attention_fake, fake_impl=unified_flash_attention_fake,
) )

View File

@@ -1,3 +1,5 @@
import os
import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set from typing import TYPE_CHECKING, Dict, List, Optional, Set
from unittest.mock import patch from unittest.mock import patch
@@ -7,11 +9,16 @@ import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from vllm import envs
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available) is_pin_memory_available)
@@ -86,6 +93,18 @@ class GPUModelRunner:
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states. # Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests. # Keep the states of the pre-empted requests.
@@ -268,12 +287,16 @@ class GPUModelRunner:
seq_start_loc_np[0] = 0 seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:]) np.cumsum(seq_lens, out=seq_start_loc_np[1:])
input_ids = input_ids.to(self.device, non_blocking=True) self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
positions = positions.to(self.device, non_blocking=True).long() non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)
query_start_loc = query_start_loc.to(self.device, non_blocking=True) query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
@@ -287,7 +310,7 @@ class GPUModelRunner:
# token from the partial request. # token from the partial request.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
return input_ids, positions, attn_metadata, logits_indices return attn_metadata, logits_indices
def _prepare_sampling( def _prepare_sampling(
self, self,
@@ -310,16 +333,26 @@ class GPUModelRunner:
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
self._update_states(scheduler_output) self._update_states(scheduler_output)
inputs = self._prepare_inputs(scheduler_output) attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
input_ids, positions, attn_metadata, logits_indices = inputs num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self._get_padded_batch_size(
num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = num_scheduled_tokens
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=self.input_ids[:num_input_tokens],
positions=positions, positions=self.positions[:num_input_tokens],
kv_caches=self.kv_caches, kv_caches=self.kv_caches,
attn_metadata=attn_metadata, attn_metadata=None,
) )
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices] hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
@@ -371,6 +404,18 @@ class GPUModelRunner:
return model_runner_output return model_runner_output
def load_model(self) -> None: def load_model(self) -> None:
if self.use_cuda_graph:
# FIXME(woosuk): Currently, the custom ops are not supported
# in the piecewise compilation mode. We rely on TorchInductor
# to optimize the model.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_flash_attention"],
use_inductor=True,
))
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117 with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
@@ -381,27 +426,62 @@ class GPUModelRunner:
self.model_memory_usage / float(2**30)) self.model_memory_usage / float(2**30))
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
input_ids = torch.zeros(num_tokens, # use an empty tensor instead of `None` to force Dynamo to pass
dtype=torch.int32, # it by reference, rather than by specializing on the value `None`.
device=self.device) # the `dtype` argument does not matter, and we use `float32` as
positions = torch.zeros(num_tokens, # a placeholder (it has wide hardware support).
dtype=torch.long, # it is important to create tensors inside the loop, rather than
device=self.device) # multiplying the list, to prevent Dynamo from treating them as
kv_caches = [None for _ in range(self.num_attn_layers)] # tensor aliasing.
model(input_ids, positions, kv_caches, attn_metadata=None) dummy_kv_caches = [
return torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(self.input_ids,
self.positions,
dummy_kv_caches,
attn_metadata=None)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
self._dummy_run(self.model, self.max_num_tokens) self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize() torch.cuda.synchronize()
return
@torch.inference_mode() @torch.inference_mode()
def capture_model(self) -> None: def capture_model(self) -> None:
# TODO: Implement CUDA graph support. if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
CompilationLevel.PIECEWISE)
return return
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_kv_cache(self, num_blocks: int) -> None: def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0 assert len(self.kv_caches) == 0
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
@@ -412,6 +492,13 @@ class GPUModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device)) device=self.device))
def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# Return the first entry of self.cudagraph_batch_sizes that is >= batch_size,
# or None when batch_size is larger than every entry in that list.
# TODO: Optimize this?
for size in self.cudagraph_batch_sizes:
if batch_size <= size:
return size
return None
@dataclass @dataclass
class CachedRequestState: class CachedRequestState: