[Feature][CI]: compare func & no_func outputs in test_functionalization.py (#35481)

Signed-off-by: Bhuminjay <bhuminjaysoni@gmail.com>
Signed-off-by: Bhuminjay Soni <Soni5Happy@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Bhuminjay Soni
2026-03-04 23:31:16 +05:30
committed by GitHub
parent fd3bfe74c9
commit fb3e78ab09

View File

@@ -309,12 +309,15 @@ def test_fix_functionalization(
model = model_class()
inputs_func = model.example_inputs()
inputs_no_func = copy.deepcopy(inputs_func)
model_func = model_class()
model_no_func = copy.deepcopy(model_func)
model_func = copy.deepcopy(model)
model_no_func = copy.deepcopy(model)
model_func = torch.compile(model_func, backend=backend_func)
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
model_func(*inputs_func)
model_no_func(*inputs_no_func)
# deepcopy inputs to prevent potential in place mutation
outputs_func = model_func(*copy.deepcopy(inputs_func))
outputs_no_func = model_no_func(*copy.deepcopy(inputs_no_func))
torch.testing.assert_close(outputs_func, outputs_no_func)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
@@ -332,8 +335,3 @@ def test_fix_functionalization(
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
# TODO (Rohan138): compare the outputs from model_func and model_no_func
# currently runs into errors while comparing `TestFusedAddRMSNorm`
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
# torch.testing.assert_close(outputs_func, outputs_no_func)