[Kernel] Added FlashInfer FP8 per-tensor GEMMs (#22895)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
nvjullin
2025-08-26 21:54:04 +08:00
committed by GitHub
parent b78bed1bc5
commit f66673a39d
9 changed files with 198 additions and 36 deletions

View File

@@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
-self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
-use_per_token_if_dynamic=False)
+self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,