[CUDA] Enable full cudagraph for FlashMLA (#18581)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
@@ -14,6 +14,7 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
@@ -76,7 +77,8 @@ class SillyModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _test_simple_piecewise_compile(*, use_inductor):
|
||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
||||
def test_simple_piecewise_compile(use_inductor):
|
||||
assert VLLM_USE_V1
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_captured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
|
||||
model(inputs)
|
||||
|
||||
@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=True)
|
||||
|
||||
|
||||
def test_simple_piecewise_compile_no_inductor():
|
||||
_test_simple_piecewise_compile(use_inductor=False)
|
||||
|
||||
Reference in New Issue
Block a user