[Nvidia] Integrate SM100 cudnn prefill API to MLA prefill (#20411)
Signed-off-by: Elfie Guo <elfieg@nvidia.com>
Co-authored-by: Elfie Guo <eflieg@nvidia.com>
This commit is contained in:
5
vllm/envs.py
Normal file → Executable file
5
vllm/envs.py
Normal file → Executable file
@@ -139,6 +139,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
||||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
|
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
|
||||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
|
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
|
||||||
|
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||||
VLLM_LOOPBACK_IP: str = ""
|
VLLM_LOOPBACK_IP: str = ""
|
||||||
|
|
||||||
|
|
||||||
@@ -962,6 +963,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
|
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
|
||||||
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
|
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
|
||||||
|
|
||||||
|
# Controls whether or not to use cudnn prefill
|
||||||
|
"VLLM_USE_CUDNN_PREFILL":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
|
||||||
|
|
||||||
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
|
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
|
||||||
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
|
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
|
||||||
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
|
lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
|
||||||
|
|||||||
113
vllm/v1/attention/backends/mla/common.py
Normal file → Executable file
113
vllm/v1/attention/backends/mla/common.py
Normal file → Executable file
@@ -194,6 +194,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
@@ -225,6 +226,8 @@ except ImportError:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
|
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
|
||||||
|
from flashinfer.prefill import ( # noqa: F401
|
||||||
|
cudnn_batch_prefill_with_kv_cache)
|
||||||
flashinfer_available = True
|
flashinfer_available = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
flashinfer_available = False
|
flashinfer_available = False
|
||||||
@@ -236,6 +239,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
CUDNN_WORKSPACE_SIZE = 12800
|
||||||
|
|
||||||
|
|
||||||
class MLACommonBackend(AttentionBackend):
|
class MLACommonBackend(AttentionBackend):
|
||||||
|
|
||||||
@@ -294,6 +299,7 @@ class MLACommonPrefillMetadata:
|
|||||||
starts: torch.Tensor
|
starts: torch.Tensor
|
||||||
seq_tot: list[int]
|
seq_tot: list[int]
|
||||||
max_seq_lens: list[int]
|
max_seq_lens: list[int]
|
||||||
|
seq_lens: torch.Tensor
|
||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
|
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
|
|||||||
default_factory=list)
|
default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CudnnPrefillMetadata(MLACommonPrefillMetadata):
|
||||||
|
|
||||||
|
class ChunkedContextMetadata(
|
||||||
|
MLACommonPrefillMetadata.ChunkedContextMetadata):
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
|
query_seq_lens: Optional[torch.Tensor] = None
|
||||||
|
cudnn_workspace: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MLACommonDecodeMetadata:
|
class MLACommonDecodeMetadata:
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):
|
|||||||
|
|
||||||
decode: Optional[D] = None
|
decode: Optional[D] = None
|
||||||
prefill: Optional[Union[MLACommonPrefillMetadata,
|
prefill: Optional[Union[MLACommonPrefillMetadata,
|
||||||
FlashInferPrefillMetadata]] = None
|
FlashInferPrefillMetadata,
|
||||||
|
CudnnPrefillMetadata]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.head_dim is not None:
|
if self.head_dim is not None:
|
||||||
@@ -362,13 +380,19 @@ M = TypeVar("M", bound=MLACommonMetadata)
|
|||||||
|
|
||||||
|
|
||||||
def use_flashinfer_prefill() -> bool:
|
def use_flashinfer_prefill() -> bool:
|
||||||
if flashinfer_available:
|
if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
|
||||||
# For blackwell default to flashinfer prefill if it's available since
|
# For blackwell default to flashinfer prefill if it's available since
|
||||||
# it's faster than FA2.
|
# it's faster than FA2.
|
||||||
return current_platform.has_device_capability(100)
|
return current_platform.has_device_capability(100)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def use_cudnn_prefill() -> bool:
|
||||||
|
if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
|
||||||
|
return current_platform.has_device_capability(100)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Currently 394MB, this can be tuned based on GEMM sizes used.
|
# Currently 394MB, this can be tuned based on GEMM sizes used.
|
||||||
# Chosen to be the same as sglang:
|
# Chosen to be the same as sglang:
|
||||||
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
|
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
|
||||||
@@ -427,11 +451,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
dtype=model_config.dtype,
|
dtype=model_config.dtype,
|
||||||
device=runner.device,
|
device=runner.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.block_table = block_table
|
self.block_table = block_table
|
||||||
|
|
||||||
|
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||||
self._use_fi_prefill = use_flashinfer_prefill()
|
self._use_fi_prefill = use_flashinfer_prefill()
|
||||||
self.prefill_metadata_cls = FlashInferPrefillMetadata \
|
self.prefill_metadata_cls = (
|
||||||
if self._use_fi_prefill else MLACommonPrefillMetadata
|
FlashInferPrefillMetadata
|
||||||
|
if self._use_fi_prefill else CudnnPrefillMetadata
|
||||||
|
if self._use_cudnn_prefill else MLACommonPrefillMetadata)
|
||||||
|
|
||||||
if self._use_fi_prefill:
|
if self._use_fi_prefill:
|
||||||
self._workspace_buffer = torch.empty(
|
self._workspace_buffer = torch.empty(
|
||||||
@@ -447,6 +475,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
self._global_hyperparameters = infer_global_hyperparameters(
|
self._global_hyperparameters = infer_global_hyperparameters(
|
||||||
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
|
get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
|
||||||
|
|
||||||
|
if self._use_cudnn_prefill:
|
||||||
|
self.cudnn_workspace = torch.empty(
|
||||||
|
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
|
||||||
|
dtype=torch.int8,
|
||||||
|
device=runner.device,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
|
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
|
||||||
qo_indptr = prefill.query_start_loc
|
qo_indptr = prefill.query_start_loc
|
||||||
|
|
||||||
@@ -692,15 +727,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
out=cu_seq_lens_cpu[:, 1:],
|
out=cu_seq_lens_cpu[:, 1:],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
chunked_context_metadata_cls = \
|
||||||
|
CudnnPrefillMetadata.ChunkedContextMetadata \
|
||||||
|
if self._use_cudnn_prefill else \
|
||||||
|
MLACommonPrefillMetadata.ChunkedContextMetadata
|
||||||
|
|
||||||
chunked_context_metadata = \
|
chunked_context_metadata = \
|
||||||
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
chunked_context_metadata_cls(
|
||||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||||
starts=chunk_starts.to(device, non_blocking=True),
|
starts=chunk_starts.to(device, non_blocking=True),
|
||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
|
seq_lens=chunk_seq_lens,
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._use_cudnn_prefill:
|
||||||
|
chunked_context_metadata.seq_lens = chunk_seq_lens
|
||||||
|
|
||||||
assert max(chunked_context_metadata.max_seq_lens) <= \
|
assert max(chunked_context_metadata.max_seq_lens) <= \
|
||||||
self.chunked_prefill_workspace_size
|
self.chunked_prefill_workspace_size
|
||||||
|
|
||||||
@@ -711,6 +755,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._use_cudnn_prefill:
|
||||||
|
assert isinstance(prefill_metadata, CudnnPrefillMetadata)
|
||||||
|
prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
|
||||||
|
- prefill_query_start_loc[:-1]
|
||||||
|
prefill_metadata.cudnn_workspace = self.cudnn_workspace
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if self._num_decodes > 0:
|
if self._num_decodes > 0:
|
||||||
decode_metadata = self._build_decode(
|
decode_metadata = self._build_decode(
|
||||||
@@ -794,6 +844,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
||||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
||||||
self._pad_v = False
|
self._pad_v = False
|
||||||
|
elif use_cudnn_prefill():
|
||||||
|
logger.debug_once("Using CUDNN prefill for MLA")
|
||||||
|
self._run_prefill_context_chunk = \
|
||||||
|
self._run_prefill_context_chunk_cudnn
|
||||||
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
|
||||||
|
self._pad_v = False
|
||||||
else: # Use FlashAttention
|
else: # Use FlashAttention
|
||||||
logger.debug_once("Using FlashAttention prefill for MLA")
|
logger.debug_once("Using FlashAttention prefill for MLA")
|
||||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
|
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
|
||||||
@@ -882,6 +938,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
return_lse=return_softmax_lse,
|
return_lse=return_softmax_lse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
||||||
|
q, k, v, return_softmax_lse):
|
||||||
|
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||||
|
assert prefill.query_seq_lens is not None
|
||||||
|
output, lse = cudnn_batch_prefill_with_kv_cache(
|
||||||
|
q=q,
|
||||||
|
k_cache=k,
|
||||||
|
v_cache=v,
|
||||||
|
scale=self.scale,
|
||||||
|
workspace_buffer=prefill.cudnn_workspace,
|
||||||
|
max_token_per_sequence=prefill.max_query_len,
|
||||||
|
max_sequence_kv=prefill.max_query_len,
|
||||||
|
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||||
|
actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||||
|
causal=True,
|
||||||
|
return_lse=True, # do not support False for now
|
||||||
|
is_cuda_graph_compatible=
|
||||||
|
True,  # Indicates whether actual_seq_lens are on GPU or CPU.
|
||||||
|
)
|
||||||
|
if return_softmax_lse:
|
||||||
|
return output, lse
|
||||||
|
return output
|
||||||
|
|
||||||
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
|
def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
|
||||||
chunk_idx: int, q, k, v):
|
chunk_idx: int, q, k, v):
|
||||||
assert prefill.chunked_context is not None
|
assert prefill.chunked_context is not None
|
||||||
@@ -908,6 +987,30 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_prefill_context_chunk_cudnn(self,
|
||||||
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
chunk_idx: int, q, k, v):
|
||||||
|
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||||
|
assert prefill.chunked_context is not None
|
||||||
|
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
|
||||||
|
assert prefill.query_seq_lens is not None
|
||||||
|
return cudnn_batch_prefill_with_kv_cache(
|
||||||
|
q=q,
|
||||||
|
k_cache=k,
|
||||||
|
v_cache=v,
|
||||||
|
scale=self.scale,
|
||||||
|
workspace_buffer=prefill.cudnn_workspace,
|
||||||
|
max_token_per_sequence=prefill.max_query_len,
|
||||||
|
max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
|
||||||
|
actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
|
||||||
|
actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
|
||||||
|
view(-1, 1, 1, 1),
|
||||||
|
causal=False,
|
||||||
|
return_lse=True,
|
||||||
|
is_cuda_graph_compatible=
|
||||||
|
True,  # Indicates whether actual_seq_lens are on GPU or CPU.
|
||||||
|
)
|
||||||
|
|
||||||
def _v_up_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user