[torch.compile] Inductor code caching fix (#10273)

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
This commit is contained in:
Luka Govedič
2024-11-21 00:44:57 -05:00
committed by GitHub
parent 9d827170a3
commit 8b0fe06c89
14 changed files with 602 additions and 286 deletions

View File

@@ -38,12 +38,6 @@ class TestModel(torch.nn.Module):
return y3
# Init does pattern registration, which can only happen once
config = CompilationConfig(enable_fusion=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
pytest.skip("Only test eps=1e-5 for now")
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)