Add ability to use CUDAGraphs with use_inductor=False (#17345)
Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
|
||||
@torch.inference_mode
|
||||
def run_model(llama_config,
|
||||
use_compile: bool,
|
||||
use_inductor: bool,
|
||||
split_attn: bool = False) -> torch.Tensor:
|
||||
|
||||
if use_compile:
|
||||
compilation_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
use_inductor=use_inductor,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
)
|
||||
if split_attn:
|
||||
@@ -304,7 +306,7 @@ def run_model(llama_config,
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_toy_llama():
|
||||
def _test_toy_llama(*, use_inductor):
|
||||
# compare output with and without piecewise compilation
|
||||
|
||||
llama_config = LlamaConfig(hidden_size=128,
|
||||
@@ -326,8 +328,14 @@ def test_toy_llama():
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_caputured=0,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=False))
|
||||
run_model(tractable_config, use_compile=False)
|
||||
outputs.append(
|
||||
run_model(llama_config, use_inductor=False, use_compile=False))
|
||||
run_model(tractable_config, use_inductor=False, use_compile=False)
|
||||
|
||||
if use_inductor:
|
||||
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
|
||||
else:
|
||||
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@@ -336,9 +344,13 @@ def test_toy_llama():
|
||||
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
**kwargs,
|
||||
):
|
||||
outputs.append(run_model(llama_config, use_compile=True))
|
||||
run_model(tractable_config, use_compile=True)
|
||||
outputs.append(
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True))
|
||||
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@@ -353,13 +365,27 @@ def test_toy_llama():
|
||||
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(
|
||||
run_model(llama_config, use_compile=True, split_attn=True))
|
||||
run_model(tractable_config, use_compile=True, split_attn=True)
|
||||
run_model(llama_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True))
|
||||
run_model(tractable_config,
|
||||
use_inductor=use_inductor,
|
||||
use_compile=True,
|
||||
split_attn=True)
|
||||
|
||||
for i in range(1, len(outputs)):
|
||||
assert torch.allclose(outputs[0], outputs[i])
|
||||
|
||||
|
||||
def test_toy_llama_inductor():
|
||||
_test_toy_llama(use_inductor=True)
|
||||
|
||||
|
||||
def test_toy_no_inductor():
|
||||
_test_toy_llama(use_inductor=False)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def benchmark():
|
||||
from triton.testing import do_bench
|
||||
|
||||
Reference in New Issue
Block a user