[CUDA] Enable full cudagraph for FlashMLA (#18581)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-06-13 14:12:26 -04:00
committed by GitHub
parent 1015296b79
commit 3597b06a4f
17 changed files with 452 additions and 219 deletions

View File

@@ -4,7 +4,7 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import pytest
import torch
from torch import nn
from torch.library import Library
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op
global_counter = 0
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
return x
def _test_simple_piecewise_compile(*, use_inductor):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
), set_forward_context({}, vllm_config=vllm_config):
model(inputs)
@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
def test_simple_piecewise_compile_inductor():
    # Thin wrapper: run the shared piecewise-compile test with the
    # inductor backend enabled. NOTE(review): this diff removes these
    # wrappers in favor of @pytest.mark.parametrize("use_inductor", ...).
    _test_simple_piecewise_compile(use_inductor=True)
def test_simple_piecewise_compile_no_inductor():
    # Thin wrapper: run the shared piecewise-compile test with the
    # inductor backend disabled (eager-mode compilation path only).
    _test_simple_piecewise_compile(use_inductor=False)