[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
fhl2000
2025-08-15 22:01:39 +08:00
committed by GitHub
parent a0632a3e03
commit 74f441f4b5
34 changed files with 1839 additions and 597 deletions

View File

@@ -11,7 +11,8 @@ from typing import Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.config import (CompilationLevel, CUDAGraphMode,
get_current_vllm_config)
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
except Exception:
pass
if self.vllm_config.compilation_config.use_cudagraph and \
"update" in new_code.co_names:
if self.vllm_config.compilation_config.cudagraph_mode != \
CUDAGraphMode.NONE and "update" in new_code.co_names:
import depyf
src = depyf.decompile(new_code)
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa