[perf] Add fused MLA QKV + strided layernorm (#21116)
Signed-off-by: Mickael Seznec <mickael@mistral.ai> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ CUDA_DEVICES = [
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("strided_input", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
@@ -34,13 +35,17 @@ def test_rms_norm(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
strided_input: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
layer = RMSNorm(hidden_size).to(dtype=dtype)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
last_dim = 2 * hidden_size if strided_input else hidden_size
|
||||
x = torch.randn(num_tokens, last_dim, dtype=dtype)
|
||||
x = x[..., :hidden_size]
|
||||
assert x.is_contiguous() != strided_input
|
||||
x *= scale
|
||||
residual = torch.randn_like(x) * scale if add_residual else None
|
||||
|
||||
@@ -72,6 +77,7 @@ def test_rms_norm(
|
||||
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("strided_input", [False, True])
|
||||
def test_fused_rms_norm_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
@@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
|
||||
quant_scale: float,
|
||||
seed: int,
|
||||
device: str,
|
||||
strided_input: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
last_dim = 2 * hidden_size if strided_input else hidden_size
|
||||
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
|
||||
x = x_base[..., :hidden_size]
|
||||
assert x.is_contiguous() != strided_input
|
||||
|
||||
x *= scale
|
||||
if add_residual:
|
||||
residual = torch.randn_like(x) * scale
|
||||
@@ -106,9 +117,11 @@ def test_fused_rms_norm_quant(
|
||||
|
||||
# Unfused kernel is in-place so it goes second
|
||||
# Also use a separate clone of x to avoid modifying the input
|
||||
x_unfused = x.clone()
|
||||
x_unfused_base = x_base.clone()
|
||||
x_unfused = x_unfused_base[..., :hidden_size]
|
||||
assert x_unfused.is_contiguous() != strided_input
|
||||
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
|
||||
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
|
||||
quant_scale_t)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
|
||||
residual,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
|
||||
@@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
|
||||
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
|
||||
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
|
||||
|
||||
torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
|
||||
out_quant.to(dtype=torch.float32),
|
||||
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
|
||||
out_quant_fused.to(dtype=torch.float32),
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user