[perf] Add fused MLA QKV + strided layernorm (#21116)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Mickaël Seznec
2025-07-22 16:07:44 +02:00
committed by GitHub
parent 0df4d9b06b
commit 4fb56914c5
7 changed files with 214 additions and 66 deletions

View File

@@ -26,6 +26,7 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
@@ -34,13 +35,17 @@ def test_rms_norm(
dtype: torch.dtype,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
@@ -72,6 +77,7 @@ def test_rms_norm(
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
def test_fused_rms_norm_quant(
num_tokens: int,
hidden_size: int,
@@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
quant_scale: float,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x_base[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
if add_residual:
residual = torch.randn_like(x) * scale
@@ -106,9 +117,11 @@ def test_fused_rms_norm_quant(
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused = x.clone()
x_unfused_base = x_base.clone()
x_unfused = x_unfused_base[..., :hidden_size]
assert x_unfused.is_contiguous() != strided_input
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
quant_scale_t)
torch.cuda.synchronize()
@@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
residual,
atol=1e-2,
rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
@@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
out_quant.to(dtype=torch.float32),
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
out_quant_fused.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3)