Fix per file ruff ignores related to simplification (#26259)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 21:31:53 +01:00
committed by GitHub
parent 6b6e98775f
commit b893d661b1
32 changed files with 47 additions and 191 deletions

View File

@@ -66,10 +66,7 @@ def test_cutlass_mla_decode(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)

View File

@@ -52,10 +52,7 @@ def test_flash_mla(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)

View File

@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
for b in range(B):
for step in range(S):

View File

@@ -705,10 +705,7 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
if shared_experts is not None:
shared_output = shared_experts(a)
else:
shared_output = None
shared_output = shared_experts(a) if shared_experts is not None else None
torch_output = torch_experts(
a,

View File

@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

View File

@@ -84,10 +84,7 @@ def onednn_int8_gemm_test_helper(
azp = None
azp_adj = None
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
handler = ops.create_onednn_scaled_mm(
b,

View File

@@ -963,13 +963,9 @@ def make_test_metadata(
None if encoder_seq_lens is None else (sum(encoder_seq_lens))
)
if cross_test_params is None:
cross_kv_mmap = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract *cross-attention* slot_mapping and block table
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
# For encoder/decoder or encoder-only models only, extract *cross-attention*
# slot_mapping and block table (kv_mmap)
cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name)