Add ability to use CUDAGraphs with use_inductor=False (#17345)

Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2025-05-28 22:16:52 -04:00
committed by GitHub
parent 515b413ebf
commit 26b4fa45be
5 changed files with 51 additions and 13 deletions

View File

@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
@torch.inference_mode
def run_model(llama_config,
use_compile: bool,
use_inductor: bool,
split_attn: bool = False) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -304,7 +306,7 @@ def run_model(llama_config,
return output.cpu()
def test_toy_llama():
def _test_toy_llama(*, use_inductor):
# compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128,
@@ -326,8 +328,14 @@ def test_toy_llama():
num_backend_compilations=0,
num_cudagraph_caputured=0,
):
outputs.append(run_model(llama_config, use_compile=False))
run_model(tractable_config, use_compile=False)
outputs.append(
run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
if use_inductor:
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -336,9 +344,13 @@ def test_toy_llama():
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**kwargs,
):
outputs.append(run_model(llama_config, use_compile=True))
run_model(tractable_config, use_compile=True)
outputs.append(
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True))
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -353,13 +365,27 @@ def test_toy_llama():
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(llama_config, use_compile=True, split_attn=True))
run_model(tractable_config, use_compile=True, split_attn=True)
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True))
run_model(tractable_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
def test_toy_llama_inductor():
    # Run the toy-llama piecewise-compilation test with Inductor as the
    # torch.compile backend (use_inductor=True).
    _test_toy_llama(use_inductor=True)
def test_toy_no_inductor():
    # Run the same toy-llama test with the eager (non-Inductor) backend,
    # exercising CUDAGraphs with use_inductor=False — the point of this commit.
    _test_toy_llama(use_inductor=False)
@torch.inference_mode
def benchmark():
from triton.testing import do_bench