[Feature][CI]: compare func & no_func outputs in test_functionalization.py (#35481)
Signed-off-by: Bhuminjay <bhuminjaysoni@gmail.com> Signed-off-by: Bhuminjay Soni <Soni5Happy@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -309,12 +309,15 @@ def test_fix_functionalization(
|
||||
model = model_class()
|
||||
inputs_func = model.example_inputs()
|
||||
inputs_no_func = copy.deepcopy(inputs_func)
|
||||
model_func = model_class()
|
||||
model_no_func = copy.deepcopy(model_func)
|
||||
model_func = copy.deepcopy(model)
|
||||
model_no_func = copy.deepcopy(model)
|
||||
model_func = torch.compile(model_func, backend=backend_func)
|
||||
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
|
||||
model_func(*inputs_func)
|
||||
model_no_func(*inputs_no_func)
|
||||
|
||||
# deepcopy inputs to prevent potential in place mutation
|
||||
outputs_func = model_func(*copy.deepcopy(inputs_func))
|
||||
outputs_no_func = model_no_func(*copy.deepcopy(inputs_no_func))
|
||||
torch.testing.assert_close(outputs_func, outputs_no_func)
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
@@ -332,8 +335,3 @@ def test_fix_functionalization(
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
|
||||
# TODO (Rohan138): compare the outputs from model_func and model_no_func
|
||||
# currently runs into errors while comparing `TestFusedAddRMSNorm`
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
|
||||
# torch.testing.assert_close(outputs_func, outputs_no_func)
|
||||
|
||||
Reference in New Issue
Block a user