Fix per file ruff ignores related to simplification (#26259)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 21:31:53 +01:00
committed by GitHub
parent 6b6e98775f
commit b893d661b1
32 changed files with 47 additions and 191 deletions

View File

@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
-if use_bias:
-    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
-else:
-    bias = None
+bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
-if use_bias:
-    bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
-else:
-    bias = None
+bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)