[v1][torch.compile] support managing cudagraph buffer (#10203)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
|
||||
- input_buffer = torch.randn(100).cuda()
|
||||
+ inputs = torch.randn(100).cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
|
||||
):
|
||||
|
||||
with set_compile_context([1, 2]):
|
||||
- model(input_buffer)
|
||||
+ model(inputs)
|
||||
|
||||
- model(input_buffer[:2])
|
||||
- model(input_buffer[:1])
|
||||
+ model(torch.randn(2).cuda())
|
||||
+ model(torch.randn(1).cuda())
|
||||
|
||||
- input_buffer[:2].zero_()
|
||||
+ input = torch.zeros(2).cuda()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
- output = model(input_buffer[:2])
|
||||
+ output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user