[V1] Integrate Piecewise CUDA graphs (#10058)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -496,7 +496,10 @@ class PiecewiseBackend:
|
|||||||
return entry.runnable(*args)
|
return entry.runnable(*args)
|
||||||
|
|
||||||
if self.is_first_graph:
|
if self.is_first_graph:
|
||||||
logger.info("Capturing a cudagraph for shape %s",
|
# Since we capture cudagraph for many different shapes and
|
||||||
|
# capturing is fast, we don't need to log it for every shape.
|
||||||
|
# We only log it in the debug mode.
|
||||||
|
logger.debug("Capturing a cudagraph for shape %s",
|
||||||
runtime_shape)
|
runtime_shape)
|
||||||
|
|
||||||
input_addresses = [
|
input_addresses = [
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
|
|||||||
# |-------------------- seq_len ---------------------|
|
# |-------------------- seq_len ---------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
|
num_actual_tokens: int # Number of tokens excluding padding.
|
||||||
max_query_len: int
|
max_query_len: int
|
||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
max_seq_len: int
|
max_seq_len: int
|
||||||
@@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||||
"key/v_scale is not supported in FlashAttention.")
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
output = torch.ops.vllm.unified_flash_attention(
|
output = torch.empty_like(query)
|
||||||
|
torch.ops.vllm.unified_flash_attention(
|
||||||
|
output,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
@@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
|
|
||||||
def unified_flash_attention(
|
def unified_flash_attention(
|
||||||
|
output: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
@@ -168,17 +172,17 @@ def unified_flash_attention(
|
|||||||
window_size: Optional[List[int]] = None,
|
window_size: Optional[List[int]] = None,
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> None:
|
||||||
current_metadata = get_forward_context()
|
current_metadata = get_forward_context()
|
||||||
if current_metadata is None:
|
if current_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
return torch.empty_like(query)
|
return
|
||||||
|
|
||||||
assert current_metadata is not None
|
assert current_metadata is not None
|
||||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||||
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(-1, num_heads, head_size)
|
query = query.view(-1, num_heads, head_size)
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
key = key.view(-1, num_kv_heads, head_size)
|
||||||
@@ -188,18 +192,18 @@ def unified_flash_attention(
|
|||||||
key_cache = kv_cache[0]
|
key_cache = kv_cache[0]
|
||||||
value_cache = kv_cache[1]
|
value_cache = kv_cache[1]
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
key,
|
key[:num_actual_tokens],
|
||||||
value,
|
value[:num_actual_tokens],
|
||||||
kv_cache[0],
|
key_cache,
|
||||||
kv_cache[1],
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = flash_attn_varlen_func(
|
attn_output = flash_attn_varlen_func(
|
||||||
q=query,
|
q=query[:num_actual_tokens],
|
||||||
k=key_cache,
|
k=key_cache,
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||||
@@ -213,10 +217,13 @@ def unified_flash_attention(
|
|||||||
block_table=attn_metadata.block_table,
|
block_table=attn_metadata.block_table,
|
||||||
softcap=logits_soft_cap,
|
softcap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
return output.view(num_tokens, hidden_size)
|
attn_output = attn_output.view(num_actual_tokens, -1)
|
||||||
|
# TODO(woosuk): Optimize this.
|
||||||
|
output[:num_actual_tokens].copy_(attn_output)
|
||||||
|
|
||||||
|
|
||||||
def unified_flash_attention_fake(
|
def unified_flash_attention_fake(
|
||||||
|
output: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
|
|||||||
window_size: Optional[List[int]] = None,
|
window_size: Optional[List[int]] = None,
|
||||||
alibi_slopes: Optional[torch.Tensor] = None,
|
alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> None:
|
||||||
return torch.empty_like(query)
|
return
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="unified_flash_attention",
|
op_name="unified_flash_attention",
|
||||||
op_func=unified_flash_attention,
|
op_func=unified_flash_attention,
|
||||||
mutates_args=["kv_cache"],
|
mutates_args=["kv_cache", "output"],
|
||||||
fake_impl=unified_flash_attention_fake,
|
fake_impl=unified_flash_attention_fake,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -7,11 +9,16 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.compilation.compile_context import set_compile_context
|
||||||
|
from vllm.compilation.config import CompilationConfig
|
||||||
|
from vllm.compilation.levels import CompilationLevel
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MultiModalDataDict
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
from vllm.plugins import set_compilation_config
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
||||||
is_pin_memory_available)
|
is_pin_memory_available)
|
||||||
@@ -86,6 +93,18 @@ class GPUModelRunner:
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
|
||||||
|
== CompilationLevel.PIECEWISE
|
||||||
|
and not self.model_config.enforce_eager)
|
||||||
|
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||||
|
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
|
||||||
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
self.positions = torch.zeros(self.max_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
# Remove stopped requests from the cached states.
|
# Remove stopped requests from the cached states.
|
||||||
# Keep the states of the pre-empted requests.
|
# Keep the states of the pre-empted requests.
|
||||||
@@ -268,12 +287,16 @@ class GPUModelRunner:
|
|||||||
seq_start_loc_np[0] = 0
|
seq_start_loc_np[0] = 0
|
||||||
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
|
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
|
||||||
|
|
||||||
input_ids = input_ids.to(self.device, non_blocking=True)
|
self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
|
||||||
positions = positions.to(self.device, non_blocking=True).long()
|
non_blocking=True)
|
||||||
|
self.positions[:total_num_scheduled_tokens].copy_(positions,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
|
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
|
||||||
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
|
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
|
||||||
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
|
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
|
||||||
attn_metadata = FlashAttentionMetadata(
|
attn_metadata = FlashAttentionMetadata(
|
||||||
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
@@ -287,7 +310,7 @@ class GPUModelRunner:
|
|||||||
# token from the partial request.
|
# token from the partial request.
|
||||||
# TODO: Support prompt logprobs.
|
# TODO: Support prompt logprobs.
|
||||||
logits_indices = query_start_loc[1:] - 1
|
logits_indices = query_start_loc[1:] - 1
|
||||||
return input_ids, positions, attn_metadata, logits_indices
|
return attn_metadata, logits_indices
|
||||||
|
|
||||||
def _prepare_sampling(
|
def _prepare_sampling(
|
||||||
self,
|
self,
|
||||||
@@ -310,16 +333,26 @@ class GPUModelRunner:
|
|||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> ModelRunnerOutput:
|
) -> ModelRunnerOutput:
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
inputs = self._prepare_inputs(scheduler_output)
|
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||||
input_ids, positions, attn_metadata, logits_indices = inputs
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
if (self.use_cuda_graph
|
||||||
|
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||||
|
# Use piecewise CUDA graphs.
|
||||||
|
# Add padding to the batch size.
|
||||||
|
num_input_tokens = self._get_padded_batch_size(
|
||||||
|
num_scheduled_tokens)
|
||||||
|
else:
|
||||||
|
# Eager mode.
|
||||||
|
num_input_tokens = num_scheduled_tokens
|
||||||
|
|
||||||
with set_forward_context(attn_metadata):
|
with set_forward_context(attn_metadata):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=self.input_ids[:num_input_tokens],
|
||||||
positions=positions,
|
positions=self.positions[:num_input_tokens],
|
||||||
kv_caches=self.kv_caches,
|
kv_caches=self.kv_caches,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=None,
|
||||||
)
|
)
|
||||||
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
hidden_states = hidden_states[logits_indices]
|
hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(hidden_states, None)
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
|
|
||||||
@@ -371,6 +404,18 @@ class GPUModelRunner:
|
|||||||
return model_runner_output
|
return model_runner_output
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
|
if self.use_cuda_graph:
|
||||||
|
# FIXME(woosuk): Currently, the custom ops are not supported
|
||||||
|
# in the piecewise compilation mode. We rely on TorchInductor
|
||||||
|
# to optimize the model.
|
||||||
|
os.environ["VLLM_CUSTOM_OPS"] = "none"
|
||||||
|
set_compilation_config(
|
||||||
|
CompilationConfig(
|
||||||
|
use_cudagraph=True,
|
||||||
|
non_cudagraph_ops=["vllm.unified_flash_attention"],
|
||||||
|
use_inductor=True,
|
||||||
|
))
|
||||||
|
|
||||||
logger.info("Starting to load model %s...", self.model_config.model)
|
logger.info("Starting to load model %s...", self.model_config.model)
|
||||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||||
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
|
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
|
||||||
@@ -381,27 +426,62 @@ class GPUModelRunner:
|
|||||||
self.model_memory_usage / float(2**30))
|
self.model_memory_usage / float(2**30))
|
||||||
|
|
||||||
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
|
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
|
||||||
input_ids = torch.zeros(num_tokens,
|
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
dtype=torch.int32,
|
# it by reference, rather by specializing on the value `None`.
|
||||||
device=self.device)
|
# the `dtype` argument does not matter, and we use `float32` as
|
||||||
positions = torch.zeros(num_tokens,
|
# a placeholder (it has wide hardware support).
|
||||||
dtype=torch.long,
|
# it is important to create tensors inside the loop, rather than
|
||||||
device=self.device)
|
# multiplying the list, to avoid Dynamo from treating them as
|
||||||
kv_caches = [None for _ in range(self.num_attn_layers)]
|
# tensor aliasing.
|
||||||
model(input_ids, positions, kv_caches, attn_metadata=None)
|
dummy_kv_caches = [
|
||||||
return
|
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||||
|
for _ in range(self.num_attn_layers)
|
||||||
|
]
|
||||||
|
with set_forward_context(None): # noqa: SIM117
|
||||||
|
with set_compile_context(self.cudagraph_batch_sizes):
|
||||||
|
# Trigger compilation for general shape.
|
||||||
|
model(self.input_ids,
|
||||||
|
self.positions,
|
||||||
|
dummy_kv_caches,
|
||||||
|
attn_metadata=None)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
self._dummy_run(self.model, self.max_num_tokens)
|
self._dummy_run(self.model, self.max_num_tokens)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
# TODO: Implement CUDA graph support.
|
if not self.use_cuda_graph:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping CUDA graph capture. Please set "
|
||||||
|
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
|
||||||
|
CompilationLevel.PIECEWISE)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||||
|
|
||||||
|
with set_forward_context(None):
|
||||||
|
# Trigger CUDA graph capture for specific shapes.
|
||||||
|
# Capture the large shapes first so that the smaller shapes
|
||||||
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||||
|
self.model(
|
||||||
|
self.input_ids[:num_tokens],
|
||||||
|
self.positions[:num_tokens],
|
||||||
|
kv_caches=self.kv_caches,
|
||||||
|
attn_metadata=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||||
|
elapsed_time = end_time - start_time
|
||||||
|
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
||||||
|
# This usually takes 5~20 seconds.
|
||||||
|
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||||
|
elapsed_time, cuda_graph_size / (1 << 30))
|
||||||
|
|
||||||
def initialize_kv_cache(self, num_blocks: int) -> None:
|
def initialize_kv_cache(self, num_blocks: int) -> None:
|
||||||
assert len(self.kv_caches) == 0
|
assert len(self.kv_caches) == 0
|
||||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||||
@@ -412,6 +492,13 @@ class GPUModelRunner:
|
|||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
device=self.device))
|
device=self.device))
|
||||||
|
|
||||||
|
def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
|
||||||
|
# TODO: Optimize this?
|
||||||
|
for size in self.cudagraph_batch_sizes:
|
||||||
|
if batch_size <= size:
|
||||||
|
return size
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CachedRequestState:
|
class CachedRequestState:
|
||||||
|
|||||||
Reference in New Issue
Block a user