[v1][torch.compile] support managing cudagraph buffer (#10203)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
youkaichao
2024-11-11 11:10:27 -08:00
committed by GitHub
parent d7a4f2207b
commit 330e82d34a
4 changed files with 59 additions and 8 deletions

View File

@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
-input_buffer = torch.randn(100).cuda()
+inputs = torch.randn(100).cuda()
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
):
with set_compile_context([1, 2]):
-model(input_buffer)
+model(inputs)
-model(input_buffer[:2])
-model(input_buffer[:1])
+model(torch.randn(2).cuda())
+model(torch.randn(1).cuda())
-input_buffer[:2].zero_()
+input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
-output = model(input_buffer[:2])
+output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))