full cudagraph for flex-attn (#36298)

Signed-off-by: shunting314 <shunting@meta.com>
This commit is contained in:
shunting314
2026-04-02 21:15:01 -07:00
committed by GitHub
parent 2ad7c0335f
commit 8b141ed8c3
4 changed files with 145 additions and 11 deletions

View File

@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """Constructing an LLM with full-CUDA-graph mode plus the FLEX_ATTENTION
    backend must fail: Flex_Attention is not supported with full cuda graph,
    so engine construction is expected to raise ``RuntimeError``.
    """
    full_graph_config = CompilationConfig(cudagraph_mode="FULL")
    with pytest.raises(RuntimeError):
        LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=full_graph_config,
            attention_config={"backend": "FLEX_ATTENTION"},
        )