[BugFix] Fix fusion test and add them to CI (#16287)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
|
||||
resid = torch.sqrt(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w[0],
|
||||
self.wscale[0],
|
||||
input_scale=self.scale[0])
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
|
||||
self.scale[1])
|
||||
x3 = self.fp8_linear.apply(y2,
|
||||
self.w[1],
|
||||
self.wscale[1],
|
||||
input_scale=self.scale[1])
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
return y3
|
||||
|
||||
|
||||
Reference in New Issue
Block a user