Fix per-file ruff ignores related to simplification (#26259)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
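The hunks below all apply the same change: with the per-file ruff ignores for the simplification rules removed, each four-line if/else assignment is collapsed into a single conditional expression. A minimal sketch of the pattern, assuming the removed ignores covered ruff's flake8-simplify rule SIM108 (if-else-block-instead-of-if-exp), which matches the rewrites below:

    import torch

    torch_dtype = torch.float8_e4m3fn  # example input

    # Before: the if/else form that SIM108 flags once the ignore is removed.
    if torch_dtype == torch.float8_e4m3fn:
        init_dtype = torch.bfloat16
    else:
        init_dtype = torch_dtype

    # After: the equivalent one-line conditional expression.
    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
    assert init_dtype == torch.bfloat16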
@@ -66,10 +66,7 @@ def test_cutlass_mla_decode(
     b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
 ):
     device = torch.device("cuda:0")
-    if torch_dtype == torch.float8_e4m3fn:
-        init_dtype = torch.bfloat16
-    else:
-        init_dtype = torch_dtype
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
     torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
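The bfloat16 fallback exists because most PyTorch factory and math kernels do not run directly in float8_e4m3fn: tensors are created in a wider dtype first and only then cast down. A small illustrative sketch, with made-up shapes not taken from the test:

    import torch

    # Hypothetical shapes, for illustration only.
    q = torch.randn(2, 8, 64, dtype=torch.bfloat16)  # build in bf16 first ...
    q_fp8 = q.to(torch.float8_e4m3fn)                # ... then cast down to fp8

    # Creating the tensor directly in fp8 would fail on current PyTorch:
    # torch.randn(2, 8, 64, dtype=torch.float8_e4m3fn) raises, since
    # randn has no float8 kernel.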
@@ -52,10 +52,7 @@ def test_flash_mla(
     b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
 ):
     device = torch.device("cuda:0")
-    if torch_dtype == torch.float8_e4m3fn:
-        init_dtype = torch.bfloat16
-    else:
-        init_dtype = torch_dtype
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
     torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
 
     # More efficient implementation
     # Convert decay factors to matrix form
-    if ed.dim() == 1:
-        decay = torch.exp(-ed).view(1, -1, 1, 1)
-    else:
-        decay = torch.exp(-ed)
+    decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
 
     for b in range(B):
         for step in range(S):
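The conditional expression keeps the original broadcasting behavior: a 1-D ed (one decay rate per head) is reshaped to (1, H, 1, 1) so it broadcasts over the batch and sequence dimensions, while an already-shaped ed passes through unchanged. A quick self-contained check with made-up sizes, not taken from the reference implementation:

    import torch

    B, H, S = 2, 4, 8   # made-up sizes for illustration
    ed = torch.rand(H)  # one decay rate per head

    decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
    assert decay.shape == (1, H, 1, 1)

    scores = torch.rand(B, H, S, S)
    decayed = decay * scores  # (1, H, 1, 1) broadcasts over batch and both sequence dims
    assert decayed.shape == (B, H, S, S)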