[Perf] Refactor cudagraph_support to enable full CUDA graphs for spec decoding with FlashInfer (#28479)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
a742134cc5
commit
304419576a
@@ -177,8 +177,9 @@ The following table lists backends that support full CUDA Graphs at the time of
|
|||||||
| FlashAttention v3 | `ALWAYS` | has a unified routine for both batch types (prefill/mixed and pure decode), so `FULL` mode is good |
|
| FlashAttention v3 | `ALWAYS` | has a unified routine for both batch types (prefill/mixed and pure decode), so `FULL` mode is good |
|
||||||
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
|
| Triton Attention | `ALWAYS` | prefer `FULL_AND_PIECEWISE` since it has different kernels for prefill/mixed and pure decode batches |
|
||||||
| AITER FlashAttention | `UNIFORM_BATCH`| |
|
| AITER FlashAttention | `UNIFORM_BATCH`| |
|
||||||
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
|
||||||
| FlashMLA | `UNIFORM_BATCH` | |
|
| FlashMLA | `UNIFORM_BATCH` | |
|
||||||
|
| FlashInferMLA | `UNIFORM_BATCH` | |
|
||||||
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||||
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||||
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def create_chunked_local_attention_backend(
|
|||||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
|
||||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
# to FULL_AND_PIECEWISE.
|
# to FULL_AND_PIECEWISE.
|
||||||
# TODO(luka, lucas): audit FA2 as part of:
|
# TODO(luka, lucas): audit FA2 as part of:
|
||||||
# https://github.com/vllm-project/vllm/issues/22945
|
# https://github.com/vllm-project/vllm/issues/22945
|
||||||
cudagraph_support = (
|
_cudagraph_support = (
|
||||||
AttentionCGSupport.ALWAYS
|
AttentionCGSupport.ALWAYS
|
||||||
if get_flash_attn_version() == 3
|
if get_flash_attn_version() == 3
|
||||||
else AttentionCGSupport.UNIFORM_BATCH
|
else AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from flashinfer import (
|
|||||||
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
|
||||||
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
|
||||||
from flashinfer.utils import FP4Tensor
|
from flashinfer.utils import FP4Tensor
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (
|
from vllm.attention.backends.abstract import (
|
||||||
@@ -274,10 +275,6 @@ class FlashInferMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = (
|
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
|
||||||
)
|
|
||||||
|
|
||||||
reorder_batch_threshold: int = 1
|
reorder_batch_threshold: int = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -355,6 +352,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
else:
|
else:
|
||||||
self.q_data_type = self.model_config.dtype
|
self.q_data_type = self.model_config.dtype
|
||||||
|
|
||||||
|
# Prefer TRTLLM attention for decoding in all cases.
|
||||||
|
# This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
|
||||||
|
self.use_trtllm_decode_attention = can_use_trtllm
|
||||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
|
self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
|
||||||
|
|
||||||
self._cascade_wrapper = None # Wrapper for cascade attention
|
self._cascade_wrapper = None # Wrapper for cascade attention
|
||||||
@@ -412,6 +412,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
"passing --block-size 32 or --block-size 64."
|
"passing --block-size 32 or --block-size 64."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@override
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["FlashInferMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
has_trtllm_support = can_use_trtllm_attention(
|
||||||
|
num_qo_heads=vllm_config.model_config.get_num_attention_heads(
|
||||||
|
vllm_config.parallel_config
|
||||||
|
),
|
||||||
|
num_kv_heads=kv_cache_spec.num_kv_heads,
|
||||||
|
)
|
||||||
|
if has_trtllm_support:
|
||||||
|
return AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
else:
|
||||||
|
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
|
|
||||||
def _get_workspace_buffer(self):
|
def _get_workspace_buffer(self):
|
||||||
if self._workspace_buffer is None:
|
if self._workspace_buffer is None:
|
||||||
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
|
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
|
||||||
@@ -573,17 +591,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
has_sinks=self.has_sinks,
|
has_sinks=self.has_sinks,
|
||||||
has_spec=uses_spec_reorder,
|
has_spec=uses_spec_reorder,
|
||||||
)
|
)
|
||||||
decode_use_trtllm = use_trtllm_attention(
|
decode_use_trtllm = self.use_trtllm_decode_attention
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
num_decode_tokens,
|
|
||||||
max_seq_len,
|
|
||||||
self.cache_dtype,
|
|
||||||
self.q_data_type,
|
|
||||||
is_prefill=False,
|
|
||||||
has_sinks=self.has_sinks,
|
|
||||||
has_spec=uses_spec_reorder,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (prefill_use_trtllm and decode_use_trtllm):
|
if not (prefill_use_trtllm and decode_use_trtllm):
|
||||||
if self.has_sinks:
|
if self.has_sinks:
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class GDNAttentionMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
||||||
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|
||||||
reorder_batch_threshold: int = 1
|
reorder_batch_threshold: int = 1
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ M = TypeVar("M")
|
|||||||
|
|
||||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||||
reorder_batch_threshold: int = 1
|
reorder_batch_threshold: int = 1
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
# enable full CUDA Graph support for decode-only capture
|
# enable full CUDA Graph support for decode-only capture
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
|
|||||||
|
|
||||||
|
|
||||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
||||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
|||||||
|
|
||||||
|
|
||||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||||
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
|
||||||
# ^ TODO(matt): tune this
|
# ^ TODO(matt): tune this
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ def triton_convert_req_index_to_global_index(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ def split_prefill_chunks(
|
|||||||
|
|
||||||
|
|
||||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
|||||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||||
# TODO(luka, lucas): audit this as part of:
|
# TODO(luka, lucas): audit this as part of:
|
||||||
# https://github.com/vllm-project/vllm/issues/22945
|
# https://github.com/vllm-project/vllm/issues/22945
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = (
|
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -251,7 +251,7 @@ class AiterFlashAttentionMetadata:
|
|||||||
class AiterFlashAttentionMetadataBuilder(
|
class AiterFlashAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
||||||
):
|
):
|
||||||
cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
reorder_batch_threshold: int = 1
|
reorder_batch_threshold: int = 1
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class RocmAttentionMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
|
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TritonAttentionMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -244,7 +244,8 @@ class AttentionCGSupport(enum.Enum):
|
|||||||
|
|
||||||
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
# Do not access directly. Call get_cudagraph_support() instead.
|
||||||
|
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||||
# Does this backend/builder reorder the batch?
|
# Does this backend/builder reorder the batch?
|
||||||
# If not, set this to None. Otherwise set it to the query
|
# If not, set this to None. Otherwise set it to the query
|
||||||
# length that will be pulled into the front of the batch.
|
# length that will be pulled into the front of the batch.
|
||||||
@@ -263,6 +264,15 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cudagraph_support(
|
||||||
|
cls: type["AttentionMetadataBuilder"],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
) -> AttentionCGSupport:
|
||||||
|
"""Get the cudagraph support level of this builder class."""
|
||||||
|
return cls._cudagraph_support
|
||||||
|
|
||||||
def _init_reorder_batch_threshold(
|
def _init_reorder_batch_threshold(
|
||||||
self,
|
self,
|
||||||
reorder_batch_threshold: int | None = 1,
|
reorder_batch_threshold: int | None = 1,
|
||||||
|
|||||||
@@ -4167,14 +4167,16 @@ class GPUModelRunner(
|
|||||||
return attn_groups
|
return attn_groups
|
||||||
|
|
||||||
attention_backend_maps = []
|
attention_backend_maps = []
|
||||||
attention_backend_set: set[type[AttentionBackend]] = set()
|
attention_backend_list = []
|
||||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||||
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
|
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
|
||||||
attention_backend_maps.append(attn_backends[0])
|
attention_backend_maps.append(attn_backends[0])
|
||||||
attention_backend_set.update(attn_backends[1])
|
attention_backend_list.append(attn_backends[1])
|
||||||
|
|
||||||
# Resolve cudagraph_mode before actually initializing metadata_builders
|
# Resolve cudagraph_mode before actually initializing metadata_builders
|
||||||
self._check_and_update_cudagraph_mode(attention_backend_set)
|
self._check_and_update_cudagraph_mode(
|
||||||
|
attention_backend_list, kv_cache_config.kv_cache_groups
|
||||||
|
)
|
||||||
|
|
||||||
for i, attn_backend_map in enumerate(attention_backend_maps):
|
for i, attn_backend_map in enumerate(attention_backend_maps):
|
||||||
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
|
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
|
||||||
@@ -4203,22 +4205,31 @@ class GPUModelRunner(
|
|||||||
self.calculate_reorder_batch_threshold()
|
self.calculate_reorder_batch_threshold()
|
||||||
|
|
||||||
def _check_and_update_cudagraph_mode(
|
def _check_and_update_cudagraph_mode(
|
||||||
self, attention_backends: set[type[AttentionBackend]]
|
self,
|
||||||
|
attention_backends: list[set[type[AttentionBackend]]],
|
||||||
|
kv_cache_groups: list[KVCacheGroupSpec],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Resolve the cudagraph_mode when there are multiple attention
|
Resolve the cudagraph_mode when there are multiple attention
|
||||||
backends with potentially conflicting CUDA graph support.
|
groups with potentially conflicting CUDA graph support.
|
||||||
Then initialize the cudagraph_dispatcher based on the resolved
|
Then initialize the cudagraph_dispatcher based on the resolved
|
||||||
cudagraph_mode.
|
cudagraph_mode.
|
||||||
"""
|
"""
|
||||||
min_cg_support = AttentionCGSupport.ALWAYS
|
min_cg_support = AttentionCGSupport.ALWAYS
|
||||||
min_cg_backend_name = None
|
min_cg_backend_name = None
|
||||||
|
|
||||||
for attn_backend in attention_backends:
|
for attn_backend_set, kv_cache_group in zip(
|
||||||
builder_cls = attn_backend.get_builder_cls()
|
attention_backends, kv_cache_groups
|
||||||
if builder_cls.cudagraph_support.value < min_cg_support.value:
|
):
|
||||||
min_cg_support = builder_cls.cudagraph_support
|
for attn_backend in attn_backend_set:
|
||||||
min_cg_backend_name = attn_backend.__name__
|
builder_cls = attn_backend.get_builder_cls()
|
||||||
|
|
||||||
|
cg_support = builder_cls.get_cudagraph_support(
|
||||||
|
self.vllm_config, kv_cache_group.kv_cache_spec
|
||||||
|
)
|
||||||
|
if cg_support.value < min_cg_support.value:
|
||||||
|
min_cg_support = cg_support
|
||||||
|
min_cg_backend_name = attn_backend.__name__
|
||||||
# Flexibly resolve the cudagraph mode
|
# Flexibly resolve the cudagraph mode
|
||||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||||
# check cudagraph for mixed batch is supported
|
# check cudagraph for mixed batch is supported
|
||||||
|
|||||||
Reference in New Issue
Block a user