[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -58,8 +58,7 @@ class TritonAttentionMetadata:
|
||||
|
||||
class TritonAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.ALWAYS
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported
|
||||
return True
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user