[ROCm] Enable fused_silu_mul_block_quant on ROCm (#38817)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2026-04-08 11:23:32 -05:00
committed by GitHub
parent d74a306c4b
commit 56c976c1b5
7 changed files with 28 additions and 29 deletions

View File

@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.platforms import current_platform
DTYPES = [torch.float16, torch.bfloat16]
QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8]
QUANT_DTYPES = [current_platform.fp8_dtype(), torch.int8]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
NUM_TOKENS_HIDDEN_SIZES = [
*[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]],
@@ -28,9 +28,7 @@ SCALE_UBS = [False]
GROUP_SIZES = [64, 128]
IS_SCALE_TRANSPOSED = [False, True]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
CUDA_DEVICES = [i for i in range(1 if torch.accelerator.device_count() == 1 else 2)]
def ref_silu_and_mul_per_block_quant(
@@ -60,7 +58,7 @@ def ref_silu_and_mul_per_block_quant(
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device_idx", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_per_block_quant(
default_vllm_config,
@@ -72,9 +70,11 @@ def test_silu_and_mul_per_block_quant(
group_size: int,
is_scale_transposed: bool,
seed: int,
device: str,
device_idx: str,
) -> None:
"""Test SiLU+Mul+Block Quantization kernel correctness."""
torch.accelerator.set_device_index(device_idx)
device = f"cuda:{device_idx}"
torch.random.manual_seed(seed)
torch.set_default_device(device)
@@ -147,7 +147,7 @@ def test_silu_block_quant_shapes(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=False,
)
assert out.shape == (num_tokens, hidden_size)
@@ -157,7 +157,7 @@ def test_silu_block_quant_shapes(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=True,
)
assert out.shape == (num_tokens, hidden_size)
@@ -177,12 +177,12 @@ def test_silu_block_quant_edge_cases(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=128,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=False,
)
assert out.shape == (batch_size, hidden_size)
assert out.dtype == torch.float8_e4m3fn
assert out.dtype == current_platform.fp8_dtype()
assert scales.dtype == torch.float32
assert not torch.isnan(out.float()).any()
assert not torch.isnan(scales).any()