diff --git a/tests/compile/passes/test_functionalization.py b/tests/compile/passes/test_functionalization.py index 788ae7889..8d13e622d 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -309,12 +309,15 @@ def test_fix_functionalization( model = model_class() inputs_func = model.example_inputs() inputs_no_func = copy.deepcopy(inputs_func) - model_func = model_class() - model_no_func = copy.deepcopy(model_func) + model_func = copy.deepcopy(model) + model_no_func = copy.deepcopy(model) model_func = torch.compile(model_func, backend=backend_func) model_no_func = torch.compile(model_no_func, backend=backend_no_func) - model_func(*inputs_func) - model_no_func(*inputs_no_func) + + # deepcopy inputs to prevent potential in place mutation + outputs_func = model_func(*copy.deepcopy(inputs_func)) + outputs_no_func = model_no_func(*copy.deepcopy(inputs_no_func)) + torch.testing.assert_close(outputs_func, outputs_no_func) # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): @@ -332,8 +335,3 @@ def test_fix_functionalization( found[op] = True assert all(found[op] for op in model.ops_in_model(do_fusion)) assert all(not found.get(op) for op in model.ops_not_in_model()) - - # TODO (Rohan138): compare the outputs from model_func and model_no_func - # currently runs into errors while comparing `TestFusedAddRMSNorm` - # Linked issue: https://github.com/vllm-project/vllm/issues/34996 - # torch.testing.assert_close(outputs_func, outputs_no_func)