[BugFix] Fix fusion test and add them to CI (#16287)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-04-09 02:46:45 -04:00
committed by GitHub
parent b1eb4ca152
commit 9cdde47289
3 changed files with 75 additions and 50 deletions

View File

@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
resid = torch.sqrt(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
x2 = self.fp8_linear.apply(y,
self.w[0],
self.wscale[0],
input_scale=self.scale[0])
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
self.scale[1])
x3 = self.fp8_linear.apply(y2,
self.w[1],
self.wscale[1],
input_scale=self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here
return y3