[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)
Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)

-        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
-                                      use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)

         self.scale = torch.rand(1, dtype=torch.float32)
         # Create a weight that is compatible with torch._scaled_mm,
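For context on the per-tensor path this test exercises, below is a minimal sketch (not part of this commit; shapes, tensor names, and dtypes are illustrative assumptions) of a per-tensor FP8 GEMM expressed with torch._scaled_mm, the op the test weight above is made compatible with. Note that torch._scaled_mm is a private PyTorch API whose signature and return type have changed across releases, so keyword names may need adjusting.

import torch

# Illustrative inputs; sizes are arbitrary and not taken from the test.
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

# Per-tensor quantization: a single float32 scale for the whole tensor.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = (x.abs().max() / fp8_max).float()
w_scale = (w.abs().max() / fp8_max).float()

x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# torch._scaled_mm expects the second operand in column-major layout,
# hence the transpose of the (out_features, in_features) weight.
out = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),
    scale_a=x_scale,
    scale_b=w_scale,
    out_dtype=torch.bfloat16,
)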