[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
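The fused kernels touched here still require the last-dimension stride to be 1, but now tolerate a larger row stride. A minimal sketch (plain PyTorch; the sizes are illustrative) of the slicing trick the test below uses to build such a non-contiguous input:

# Allocate a buffer twice as wide, then slice the columns: the result is a
# view whose row stride is 2 * hidden_size while the last-dim stride stays 1.
import torch

num_tokens, hidden_size = 4, 8
buf = torch.randn(num_tokens, 2 * hidden_size)
x = buf[:, :hidden_size]  # a view, not a copy

print(x.stride())         # (16, 1): row stride is 2 * hidden_size
print(x.is_contiguous())  # False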
@@ -162,6 +162,7 @@ def ops_impl(
 )
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("strided_input", [False, True])
 @torch.inference_mode()
 def test_rms_norm(
     default_vllm_config,
@@ -175,6 +176,7 @@ def test_rms_norm(
     tma_alignment: int,
     seed: int,
     device: str,
+    strided_input: bool,
 ) -> None:
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -184,17 +186,17 @@
 
     if group_size is not None and hidden_size % group_size[1] != 0:
         # skip
-        return
+        pytest.skip("Skip non-divisible group sizes")
 
     if group_size is not None and has_scale_ub:
         # blockwise baseline doesn't support scale_ub
-        return
+        pytest.skip("scale_ub not supported for blockwise/group quantization")
 
     if (
         group_size is None or quant_dtype != current_platform.fp8_dtype()
     ) and tma_alignment != 0:
         # TMA alignment is only supported for groupwise fp8 kernels
-        return
+        pytest.skip("tma alignment not supported for per-token or int8 quantization")
 
     if (
         group_size is not None
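A note on the `return` → `pytest.skip` changes above: a bare `return` silently reports unsupported parameter combinations as passed, while `pytest.skip` marks them as skipped in the test summary. A minimal illustration (hypothetical test, not from this diff):

import pytest

def test_unsupported_combo():
    group_size, hidden_size = 3, 16
    if hidden_size % group_size != 0:
        pytest.skip("Skip non-divisible group sizes")  # reported as SKIPPED, not PASSED
    assert hidden_size // group_size > 0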
@@ -202,21 +204,36 @@
         and hidden_size // group_size[1] % tma_alignment == 0
     ):
         # Skip tests where TMA alignment doesn't create extra padding to save time
-        return
+        pytest.skip("Skip TMA alignment cases where no extra padding is added")
 
     if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
         # skip
-        return
+        pytest.skip("scale_ub only supported for fp8 quantization")
 
     layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)
 
     # Make weights
     layer.weight.data.normal_(mean=1.0, std=0.1)
 
-    # Make inputs
+    # Make inputs: use a wider tensor and slice to create a non-contiguous
+    # (strided) input when strided_input=True. The last dimension stride
+    # remains 1, which the kernel requires.
     scale = 1 / (hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
-    residual = torch.randn_like(x) * scale if add_residual else None
+    last_dim = 2 * hidden_size if strided_input else hidden_size
+    x = torch.randn(num_tokens, last_dim, dtype=dtype) * scale
+    x = x[:, :hidden_size]
+
+    # dim 1 gets special-cased
+    x_is_strided = strided_input and num_tokens != 1
+    # check that the input is strided iff we expect it to be
+    assert x.is_contiguous() != x_is_strided
+
+    # Residual must still be contiguous
+    residual = (
+        torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
+        if add_residual
+        else None
+    )
 
     if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
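The `# dim 1 gets special-cased` comment above refers to PyTorch's contiguity rules: strides of size-1 dimensions are ignored, so the sliced view with a single row still reports as contiguous. A quick demonstration:

import torch

h = 8
x1 = torch.randn(1, 2 * h)[:, :h]
xn = torch.randn(4, 2 * h)[:, :h]
print(x1.is_contiguous())  # True: the size-1 row dim makes its stride irrelevant
print(xn.is_contiguous())  # False: row stride is 2 * h, not h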
@@ -260,12 +277,33 @@
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
 
-    output = torch.empty_like(x, dtype=quant_dtype)
+    output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
     scales = torch.empty(
         (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
     )
 
-    opcheck(
-        torch.ops._C.rms_norm_dynamic_per_token_quant,
-        (output, x, layer.weight, scales, 1e-5, scale_ub, residual),
-    )
+    if group_size is None:
+        opcheck(
+            torch.ops._C.rms_norm_dynamic_per_token_quant,
+            (output, x, layer.weight, scales, 1e-5, scale_ub, residual),
+        )
+    else:
+        # TODO(luka/eliza) opcheck is broken?
+        # Somehow the cloned args are getting mutated in-place,
+        # which causes the opcheck to fail.
+        # https://github.com/vllm-project/vllm/issues/36688
+        return
+        opcheck(
+            torch.ops._C.rms_norm_per_block_quant,
+            (
+                output,
+                x,
+                layer.weight,
+                scales,
+                1e-5,
+                scale_ub,
+                residual,
+                group_size[1],
+                True,  # is_scale_transposed
+            ),
+        )
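For context on the TODO above: the test's `opcheck` helper is assumed to wrap `torch.library.opcheck`, which clones the inputs and validates a custom op's registration (schema, fake/meta kernel, mutation annotations); an op that mutates arguments it does not declare as mutated fails the check, which matches the symptom described in the linked issue. A minimal sketch with a toy op (illustrative only, not a vLLM op; requires PyTorch >= 2.4):

import torch
from torch.library import custom_op, opcheck

@custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@scale.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Fake (meta) kernel: propagates shape/dtype without real computation
    return torch.empty_like(x)

# Runs opcheck's test suite (schema, fake kernel, mutation checks) on the op
opcheck(torch.ops.demo.scale, (torch.randn(4, 8), 2.0))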