[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2026-03-11 13:56:55 -04:00
committed by GitHub
parent a1a3523a56
commit 9556af87d5
9 changed files with 219 additions and 87 deletions

View File

@@ -162,6 +162,7 @@ def ops_impl(
)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode()
def test_rms_norm(
default_vllm_config,
@@ -175,6 +176,7 @@ def test_rms_norm(
tma_alignment: int,
seed: int,
device: str,
strided_input: bool,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
@@ -184,17 +186,17 @@ def test_rms_norm(
if group_size is not None and hidden_size % group_size[1] != 0:
# skip
return
pytest.skip("Skip non-divisible group sizes")
if group_size is not None and has_scale_ub:
# blockwise baseline doesn't support scale_ub
return
pytest.skip("scale_ub not supported for blockwise/group quantization")
if (
group_size is None or quant_dtype != current_platform.fp8_dtype()
) and tma_alignment != 0:
# TMA alignment is only supported for groupwise fp8 kernels
return
pytest.skip("tma alignment not supported for per-token or int8 quantization")
if (
group_size is not None
@@ -202,21 +204,36 @@ def test_rms_norm(
and hidden_size // group_size[1] % tma_alignment == 0
):
# Skip tests where TMA alignment doesn't create extra padding to save time
return
pytest.skip("Skip TMA alignment cases where no extra padding is added")
if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
# skip
return
pytest.skip("scale_ub only supported for fp8 quantization")
layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
# Make inputs: use a wider tensor and slice to create a non-contiguous
# (strided) input when strided_input=True. The last dimension stride
# remains 1, which the kernel requires.
scale = 1 / (hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
residual = torch.randn_like(x) * scale if add_residual else None
last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale
x = x[:, :hidden_size]
# a size-1 token dim is treated as contiguous regardless of its stride, so num_tokens == 1 never yields a strided input
x_is_strided = strided_input and num_tokens != 1
# check that the input is strided iff we expect it to be
assert x.is_contiguous() != x_is_strided
# Residual must still be contiguous
residual = (
torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
if add_residual
else None
)
if has_scale_ub:
rms_x, _ = ref_rms_norm(layer, x, residual)
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
@@ -260,12 +277,33 @@ def test_rms_norm(
if add_residual:
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty_like(x, dtype=quant_dtype)
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
if group_size is None:
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
else:
# TODO(luka/eliza) opcheck is broken?
# Somehow the cloned args are getting mutated in-place,
# which causes the opcheck to fail.
# https://github.com/vllm-project/vllm/issues/36688
return
opcheck(
torch.ops._C.rms_norm_per_block_quant,
(
output,
x,
layer.weight,
scales,
1e-5,
scale_ub,
residual,
group_size[1],
True, # is_scale_transposed
),
)