Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,13 +13,12 @@ QUANT_DTYPES = [current_platform.fp8_dtype()]
|
||||
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
||||
scale: torch.Tensor) -> torch.Tensor:
|
||||
def ref_impl(
|
||||
silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
silu_and_mul_out = silu_and_mul.forward_native(x)
|
||||
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
|
||||
return out
|
||||
@@ -27,9 +26,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
||||
|
||||
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
out_shape = (x.shape[0], x.shape[1] // 2)
|
||||
out = torch.empty(out_shape,
|
||||
dtype=current_platform.fp8_dtype(),
|
||||
device=x.device)
|
||||
out = torch.empty(out_shape, dtype=current_platform.fp8_dtype(), device=x.device)
|
||||
torch.ops._C.silu_and_mul_quant(out, x, scale)
|
||||
return out
|
||||
|
||||
@@ -57,7 +54,7 @@ def test_silu_and_mul(
|
||||
layer = SiluAndMul()
|
||||
|
||||
# Make inputs
|
||||
scale = (torch.randn((1), device=device, dtype=torch.float32))
|
||||
scale = torch.randn((1), device=device, dtype=torch.float32)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
ref_out = ref_impl(layer, x, scale)
|
||||
@@ -66,6 +63,7 @@ def test_silu_and_mul(
|
||||
assert ref_out.dtype == quant_dtype
|
||||
assert ops_out.dtype == quant_dtype
|
||||
assert ref_out.shape == ops_out.shape
|
||||
assert torch.allclose(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
assert torch.allclose(
|
||||
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
|
||||
)
|
||||
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))
|
||||
|
||||
Reference in New Issue
Block a user