Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -6,11 +6,7 @@ import torch
# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
torch.float16: 1e-3,
torch.bfloat16: 1.6e-2,
torch.float: 1.3e-6
}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
def get_default_atol(output) -> float:

View File

@@ -3,8 +3,7 @@
import pytest
from vllm.utils import (create_kv_caches_with_random,
create_kv_caches_with_random_flash)
from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash
@pytest.fixture()

View File

@@ -39,7 +39,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -57,10 +57,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -74,11 +77,10 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Only ROCm is supported")
@pytest.mark.parametrize("seq_lens",
[[(10, 1328), (5, 18),
(129, 463)], [(8, 523), (24, 37), (3, 2011)]])
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only ROCm is supported")
@pytest.mark.parametrize(
"seq_lens", [[(10, 1328), (5, 18), (129, 463)], [(8, 523), (24, 37), (3, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -109,34 +111,27 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
cu_seq_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
output = torch.empty_like(query)
@@ -187,5 +182,7 @@ def test_varlen_with_paged_kv(
atol, rtol = 2e-2, 2e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)

View File

@@ -42,9 +42,7 @@ BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
def ref_masked_attention(
@@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention(
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
out = out.view(num_query_heads, head_size)
@@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize(
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -143,13 +140,18 @@ def test_paged_attention(
seed: int,
device: str,
) -> None:
if ((kv_cache_dtype == "fp8" and head_size % 16)
or (version == "rocm" and head_size not in (64, 128))):
if (kv_cache_dtype == "fp8" and head_size % 16) or (
version == "rocm" and head_size not in (64, 128)
):
pytest.skip()
if (version == "rocm" and current_platform.is_navi()
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
if (
version == "rocm"
and current_platform.is_navi()
and (
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
)
):
pytest.skip()
global PARTITION_SIZE
@@ -177,18 +179,24 @@ def test_paged_attention(
block_tables_lst: list[list[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
]
block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_caches, value_caches = kv_cache_factory(
NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
@@ -214,18 +222,37 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v1,
(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
0,
0,
0,
64,
0,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
@@ -258,13 +285,34 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v2,
(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
0,
0,
0,
64,
0,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
else:
ops.paged_attention_rocm(
@@ -288,13 +336,30 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, None, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._rocm_C.paged_attention,
(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
None,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
else:
raise AssertionError(f"Unknown version: {version}")
@@ -303,18 +368,17 @@ def test_paged_attention(
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
dequantized_key_cache = torch.empty(
size=key_cache_shape, dtype=dtype, device=device
)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
dequantized_value_cache = torch.empty(
size=value_cache_shape, dtype=dtype, device=device
)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = dequantized_value_cache
@@ -367,8 +431,9 @@ def ref_multi_query_kv_attention(
if alibi_bias:
attn_mask = alibi_bias[i]
else:
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype)
@@ -390,8 +455,9 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
@@ -413,13 +479,11 @@ def test_multi_query_kv_attention(
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype)
qkv = torch.empty(
num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
@@ -429,8 +493,7 @@ def test_multi_query_kv_attention(
alibi_bias = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output = torch.empty_like(query)
start = 0
# Dynamic sequence length not supported with custom attn_bias.
@@ -442,7 +505,8 @@ def test_multi_query_kv_attention(
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
scale=scale,
)
output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
@@ -485,8 +549,9 @@ def test_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
num_seqs: int,

View File

@@ -15,16 +15,18 @@ from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": [
"TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
"CUTLASS_MLA"
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER_MLA",
"FLASH_ATTN_MLA",
"CUTLASS_MLA",
],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
@@ -40,7 +42,7 @@ DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
# "cpu": [16] # CPU uses fixed block size from test cases
"cpu": [] # FIXME(woosuk): Temporarily disable CPU tests
"cpu": [], # FIXME(woosuk): Temporarily disable CPU tests
}
@@ -48,12 +50,13 @@ def generate_params():
params = []
for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]:
backends = DEVICE_MLA_BACKENDS[
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
backends = (
DEVICE_MLA_BACKENDS[device]
if use_mla
else DEVICE_REGULAR_ATTN_BACKENDS[device]
)
for name in backends:
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
16
]
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
for block_size in block_sizes:
params.append(
pytest.param(
@@ -61,14 +64,13 @@ def generate_params():
name,
use_mla,
block_size,
id=
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
))
id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
)
)
return params
@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env(
device: str,
name: str,
@@ -83,14 +85,12 @@ def test_env(
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
if use_mla:
# ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1
@@ -101,44 +101,33 @@ def test_env(
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_ATTN"
assert backend.get_name() == expected
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
if use_mla:
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
@@ -152,28 +141,23 @@ def test_env(
if name == "CUTLASS_MLA":
if block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip(
"CUTLASS_MLA only supports block_size 128")
pytest.skip("CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA":
if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64
pytest.skip(
"FlashInfer MLA only supports block_size 32 "
"or 64")
"FlashInfer MLA only supports block_size 32 or 64"
)
else:
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
elif name == "FLASHMLA":
@@ -182,58 +166,47 @@ def test_env(
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501
is_flashmla_supported)
is_flashmla_supported,
)
is_supported, _ = is_flashmla_supported()
if not is_supported:
pytest.skip(
"FlashMLA not supported on this platform")
pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER"
assert backend.get_name() == expected
elif name == "XFORMERS":
backend = get_attn_backend(32,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "XFORMERS"
assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(32,
torch.float16,
None,
block_size,
use_mla=use_mla)
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN"
assert backend.get_name() == expected
@@ -248,14 +221,12 @@ def test_fp32_fallback(
m.setenv("VLLM_USE_V1", "1")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA"
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION"
@@ -265,16 +236,16 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# get_attn_backend
pytest.skip("Skipping as current backend selector does not " \
"handle fallbacks when a backend is set via env var.")
pytest.skip(
"Skipping as current backend selector does not "
"handle fallbacks when a backend is set via env var."
)
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch
monkeypatch.setattr(torch.cuda,
"get_device_capability",
lambda _=None: (7, 5))
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
@@ -295,17 +266,17 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
# flash-attn is not installed
import sys
original_module = sys.modules.get('vllm_flash_attn')
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
original_module = sys.modules.get("vllm_flash_attn")
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
backend = get_attn_backend(16, torch.float16, None, 16)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Restore the original module if it existed
if original_module is not None:
monkeypatch.setitem(sys.modules, 'vllm_flash_attn',
original_module)
monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module)
else:
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16)
@@ -314,8 +285,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
"""Test that invalid attention backend names raise ValueError."""
with monkeypatch.context() as m, patch(
"vllm.attention.selector.current_platform", CudaPlatform()):
with (
monkeypatch.context() as m,
patch("vllm.attention.selector.current_platform", CudaPlatform()),
):
m.setenv("VLLM_USE_V1", "1")
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

View File

@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
DTYPES = [torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
@@ -32,9 +32,7 @@ NUM_BLOCKS = [1024, 10000]
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]
@@ -85,24 +83,33 @@ def test_copy_blocks(
block_mapping.append((src, dst2))
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, kv_cache_dtype,
dtype, seed, device)
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
num_layers,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
opcheck(torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]))
opcheck(
torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]),
)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
@@ -115,8 +122,7 @@ def test_copy_blocks(
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
torch.testing.assert_close(value_cache, cloned_value_cache)
@@ -155,10 +161,17 @@ def test_reshape_and_cache(
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_caches, value_caches = kv_cache_factory(
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
@@ -176,12 +189,30 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
opcheck(
torch.ops._C_cache_ops.reshape_and_cache,
(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0]),
)
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
@@ -202,14 +233,12 @@ def test_reshape_and_cache(
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
)
torch.testing.assert_close(
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
)
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
@@ -254,15 +283,8 @@ def test_reshape_and_cache_flash(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
@@ -293,48 +315,73 @@ def test_reshape_and_cache_flash(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache_compact,
v_scale.item(), kv_cache_dtype)
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
)
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
)
else:
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
opcheck(
torch.ops._C_cache_ops.reshape_and_cache_flash,
(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0]),
)
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
triton_reshape_and_cache_flash,
)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache_compact,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype)
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
)
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype,
)
# Run the reference implementation.
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -352,14 +399,12 @@ def test_reshape_and_cache_flash(
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
cloned_key_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(result_value_cache,
cloned_value_cache,
atol=0.001,
rtol=0.1)
torch.testing.assert_close(
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
)
torch.testing.assert_close(
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
)
else:
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
torch.testing.assert_close(value_cache_compact, cloned_value_cache)
@@ -396,8 +441,8 @@ def test_swap_blocks(
current_platform.seed_everything(seed)
src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'
src_device = device if direction[0] == "cuda" else "cpu"
dst_device = device if direction[1] == "cuda" else "cpu"
src_blocks = random.sample(range(num_blocks), num_mappings)
# For the same device, mapping must not overlap
@@ -408,42 +453,62 @@ def test_swap_blocks(
dst_blocks = random.sample(range(num_blocks), num_mappings)
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, src_device)
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
src_device,
)
# Create the KV caches on the second device.
dist_key_caches, dist_value_caches = kv_cache_factory(
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
seed, dst_device)
num_blocks,
block_size,
1,
num_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
dst_device,
)
src_key_caches_clone = src_key_caches[0].clone()
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
do_opcheck = (head_size == HEAD_SIZES[0])
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck)
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck)
do_opcheck = head_size == HEAD_SIZES[0]
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck,
)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck,
)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
for src, dst in block_mapping:
torch.testing.assert_close(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
torch.testing.assert_close(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu())
torch.testing.assert_close(
src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
)
torch.testing.assert_close(
src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -489,11 +554,9 @@ def _create_mla_cache(
device: str,
) -> torch.Tensor:
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
return torch.zeros(num_blocks,
block_size,
entry_size,
dtype=cache_dtype,
device=device)
return torch.zeros(
num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
)
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@@ -533,20 +596,16 @@ def test_concat_and_cache_mla(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
@@ -558,10 +617,7 @@ def test_concat_and_cache_mla(
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
@@ -571,24 +627,18 @@ def test_concat_and_cache_mla(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
if kv_cache_dtype == "fp8":
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(result_temp,
kv_cache.contiguous(),
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(
result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
)
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
ops.convert_fp8(expected_temp,
ref_kv_cache,
scale.item(),
kv_dtype=kv_cache_dtype)
torch.testing.assert_close(result_temp,
expected_temp,
atol=0.001,
rtol=0.1)
ops.convert_fp8(
expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
)
torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
else:
torch.testing.assert_close(kv_cache, ref_kv_cache)
@@ -620,24 +670,21 @@ def test_concat_and_cache_ds_mla(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
kv_cache = _create_mla_cache(
num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
@@ -664,14 +711,16 @@ def test_concat_and_cache_ds_mla(
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
tile_scale = manual_max / 448.0
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
ops.convert_fp8(
ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8",
)
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
@@ -682,8 +731,7 @@ def test_concat_and_cache_ds_mla(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
@@ -694,12 +742,14 @@ def test_concat_and_cache_ds_mla(
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
kv_scales = kv_cache_slice.view(torch.float32)[
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
]
ref_scales = ref_cache_slice.view(torch.float32)[
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
@@ -734,8 +784,9 @@ def test_copy_blocks_mla(
kv_caches = []
for _ in range(num_layers):
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache)
@@ -752,9 +803,9 @@ def test_copy_blocks_mla(
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device=device
).view(-1, 2)
for src, dst in block_mapping:
for ref_cache in ref_caches:
@@ -795,10 +846,12 @@ def test_swap_blocks_mla(
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
dst_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype)
_fill_mla_cache(dst_cache, kv_cache_dtype)
@@ -810,9 +863,9 @@ def test_swap_blocks_mla(
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, num_mappings)
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
block_mapping_tensor = torch.tensor(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
@@ -827,7 +880,8 @@ def test_swap_blocks_mla(
src_cache_clone[src].cpu(),
dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.")
f"{dst} in dst_cache.",
)
@pytest.mark.parametrize("kv_lora_rank", [512])
@@ -840,32 +894,36 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_gather_and_maybe_dequant_cache_mla(
kv_lora_rank,
qk_rope_head_dim,
block_size,
num_blocks,
max_seq_len,
batch_size,
dtype,
kv_cache_dtype,
device,
):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
block_table = torch.empty(
(batch_size, num_blocks), dtype=torch.int32, device=device
)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
@@ -893,10 +951,8 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
remaining = s - (tot - 1) * block_size
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data, scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)
@@ -907,14 +963,29 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
opcheck(
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
kv_cache_dtype,
scale,
None,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
ops.gather_and_maybe_dequant_cache(
src_cache,
dst,
block_table,
cu_seq_lens,
batch_size,
kv_cache_dtype,
scale,
None,
)
torch.testing.assert_close(dst, expected)
@@ -925,42 +996,46 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize(
"kv_cache_dtype", ["auto"]
) # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_cp_gather_cache_mla(
kv_lora_rank,
qk_rope_head_dim,
block_size,
num_blocks,
max_seq_len,
batch_size,
dtype,
kv_cache_dtype,
device,
):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
src_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
block_table = torch.empty(
(batch_size, num_blocks), dtype=torch.int32, device=device
)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)
expected_batches = []
for b in range(batch_size):
@@ -1016,20 +1091,16 @@ def test_concat_and_cache_mla_cpu(
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
kv_cache = _create_mla_cache(
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
@@ -1041,10 +1112,7 @@ def test_concat_and_cache_mla_cpu(
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
@@ -1054,6 +1122,5 @@ def test_concat_and_cache_mla_cpu(
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)

View File

@@ -7,11 +7,12 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states)
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
from vllm.vllm_flash_attn import (
fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported,
)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
@@ -37,21 +38,14 @@ def test_merge_kernel(
assert num_query_heads % num_kv_heads == 0
# Prepare inputs.
prefix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
suffix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
# Run the kernel.
output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
# Reference implementation.
max_lse = torch.maximum(prefix_lse, suffix_lse)
@@ -97,8 +91,10 @@ def test_cascade(
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
current_platform.seed_everything(0)
@@ -107,11 +103,9 @@ def test_cascade(
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
seq_lens, common_prefix_len = seq_lens_and_common_prefix
@@ -122,26 +116,21 @@ def test_cascade(
max_kv_len = max(kv_lens)
total_num_query_tokens = sum(query_lens)
query = torch.randn(total_num_query_tokens,
num_query_heads,
head_size,
dtype=dtype)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
assert common_prefix_len > 0
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
# Make sure the first `num_common_kv_blocks` blocks are the same.
block_tables[:, :num_common_kv_blocks] = \
block_tables[0, :num_common_kv_blocks]
block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks]
# Run the regular attention.
ref_output = flash_attn_varlen_func(
@@ -161,8 +150,7 @@ def test_cascade(
# Run cascade attention.
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)

View File

@@ -12,33 +12,37 @@ from vllm.platforms import current_platform
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False,
diff_threshold: Optional[float] = None) -> None:
def cal_diff(
x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False,
diff_threshold: Optional[float] = None,
) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
if diff_threshold is not None:
# directly compare the cos_diff with the threshold
assert cos_diff < diff_threshold
else:
# use the default threshold
if (use_fp8):
if use_fp8:
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5
CUTLASS_MLA_UNSUPPORTED_REASON = \
"Cutlass MLA Requires compute capability of 10 or above." \
if not current_platform.is_device_capability(100) \
CUTLASS_MLA_UNSUPPORTED_REASON = (
"Cutlass MLA Requires compute capability of 10 or above."
if not current_platform.is_device_capability(100)
else "Cutlass MLA is supported"
)
@pytest.mark.skipif(not current_platform.has_device_capability(100),
reason=CUTLASS_MLA_UNSUPPORTED_REASON)
@pytest.mark.skipif(
not current_platform.has_device_capability(100),
reason=CUTLASS_MLA_UNSUPPORTED_REASON,
)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@@ -54,11 +58,13 @@ CUTLASS_MLA_UNSUPPORTED_REASON = \
[
torch.bfloat16,
# fp8 can have occasional precision-related failures.
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2))
])
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)),
],
)
@torch.inference_mode()
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
causal, varlen, torch_dtype):
def test_cutlass_mla_decode(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
@@ -70,24 +76,25 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
torch.manual_seed(42)
random.seed(42)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
)
use_fp8 = torch_dtype == torch.float8_e4m3fn
scale = math.sqrt(d)**(-1)
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
scale = math.sqrt(d) ** (-1)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
s_q)
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
dtype=torch.int32).view(
b, max_seqlen_pad // block_size)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
blocked_v = blocked_k[..., :dv]
@@ -121,22 +128,29 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
q_pe = q_pe_padded
kv_cache_flat = blocked_k.squeeze(2)
device_properties = torch.cuda.get_device_properties(
torch.device("cuda:0"))
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
sm_count = device_properties.multi_processor_count
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen * block_size, b, sm_count, num_kv_splits=1)
workspace = torch.empty(workspace_size,
device="cuda",
dtype=torch.uint8)
max_seqlen * block_size, b, sm_count, num_kv_splits=1
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
output_lse = torch.empty((b, MAX_HEADS),
dtype=torch.float32,
device=q_nope.device)
ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
kv_cache_flat, cache_seqlens, block_table,
workspace, scale, 1)
output_lse = torch.empty(
(b, MAX_HEADS), dtype=torch.float32, device=q_nope.device
)
ops.sm100_cutlass_mla_decode(
out_ans,
output_lse,
q_nope,
q_pe,
kv_cache_flat,
cache_seqlens,
block_table,
workspace,
scale,
1,
)
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
def scaled_dot_product_attention(query, key, value, is_causal=False):
@@ -150,8 +164,7 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k,
dtype=torch.bool).tril(diagonal=s_k - s_q)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
@@ -161,10 +174,16 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
blocked_k_ = (
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
if use_fp8
else blocked_k
)
blocked_v_ = (
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
if use_fp8
else blocked_v
)
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
@@ -191,8 +210,9 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
t = triton.testing.do_bench(cutlass_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
torch.finfo(torch_dtype).bits // 8
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
)

View File

@@ -7,9 +7,14 @@ import torch
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
fp8_paged_mqa_logits, get_num_sms,
get_paged_mqa_logits_metadata)
from vllm.utils.deep_gemm import (
_ceil_to_ue8m0,
calc_diff,
fp8_mqa_logits,
fp8_paged_mqa_logits,
get_num_sms,
get_paged_mqa_logits_metadata,
)
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
@@ -24,17 +29,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
device=x.device,
dtype=torch.uint8,
)
x_fp8[:, :block_size * head_dim] = x_scaled.view(
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
x_fp8[:,
block_size * head_dim:] = sf.view(num_blocks,
block_size).view(dtype=torch.uint8)
x_fp8[:, : block_size * head_dim] = x_scaled.view(
num_blocks, block_size * head_dim
).view(dtype=torch.uint8)
x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
dtype=torch.uint8
)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
def per_custom_dims_cast_to_fp8(
x: torch.Tensor, dims: tuple,
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
x: torch.Tensor, dims: tuple, use_ue8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
@@ -69,10 +75,12 @@ def _ref_fp8_mqa_logits(
q = q.float()
k = k.float()
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
< cu_seqlen_ke[:, None])
mask_lo = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi
score = torch.einsum("mhd,and->hmn", q, k)
@@ -84,14 +92,15 @@ def _ref_fp8_mqa_logits(
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_mqa_logits():
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
for seq_len in (512, ):
for seq_len_kv in (1024, ):
for seq_len in (512,):
for seq_len_kv in (1024,):
for disable_cp in (False, True):
q = torch.randn(
seq_len,
@@ -100,24 +109,23 @@ def test_deepgemm_fp8_mqa_logits():
device="cuda",
dtype=torch.bfloat16,
)
kv = torch.randn(seq_len_kv,
head_dim,
device="cuda",
dtype=torch.bfloat16)
weights = torch.randn(seq_len,
num_heads,
device="cuda",
dtype=torch.float32)
kv = torch.randn(
seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
)
weights = torch.randn(
seq_len, num_heads, device="cuda", dtype=torch.float32
)
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.arange(seq_len, dtype=torch.int,
device="cuda") + (seq_len_kv - seq_len)
ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + (
seq_len_kv - seq_len
)
else:
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
ref_logits = _ref_fp8_mqa_logits(
@@ -157,11 +165,10 @@ def _ref_fp8_paged_mqa_logits(
context_lens_list = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens_list[i]
q_offsets = torch.arange(context_len - next_n,
context_len,
device="cuda")
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
0, 1).contiguous())
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
@@ -170,28 +177,30 @@ def _ref_fp8_paged_mqa_logits(
(block_rk + 1) * block_size,
device="cuda",
)
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
<= q_offsets[:, None])
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
)
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype),
logits.dtype
),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n:(i + 1) * next_n,
block_rk * block_size:(block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
float("-inf"))
i * next_n : (i + 1) * next_n,
block_rk * block_size : (block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_paged_mqa_logits():
torch.manual_seed(0)
random.seed(0)
@@ -199,7 +208,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
max_model_len = 4096
for batch_size, next_n in [(4, 1), (2, 2)]:
for heads, index_dim in [(32, 128)]:
for avg_kv in (2048, ):
for avg_kv in (2048,):
num_blocks, blocksize = max_model_len * 2, 64
q = torch.randn(
@@ -218,12 +227,14 @@ def test_deepgemm_fp8_paged_mqa_logits():
dtype=torch.float32,
)
context_lens = (torch.randint(int(0.8 * avg_kv),
int(1.2 * avg_kv),
(batch_size, )).cuda().to(
torch.int32))
max_block_len = ((context_lens.max().item() + blocksize - 1) //
blocksize * blocksize)
context_lens = (
torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
.cuda()
.to(torch.int32)
)
max_block_len = (
(context_lens.max().item() + blocksize - 1) // blocksize * blocksize
)
block_tables = torch.zeros(
(batch_size, max_block_len),
device="cuda",
@@ -243,7 +254,8 @@ def test_deepgemm_fp8_paged_mqa_logits():
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms())
context_lens, blocksize, get_num_sms()
)
logits = fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
@@ -263,15 +275,18 @@ def test_deepgemm_fp8_paged_mqa_logits():
max_model_len,
)
positions = (torch.arange(max_model_len,
device="cuda").unsqueeze(0).expand(
batch_size * next_n, -1))
row_indices = (
torch.arange(batch_size * next_n, device="cuda") // next_n)
positions = (
torch.arange(max_model_len, device="cuda")
.unsqueeze(0)
.expand(batch_size * next_n, -1)
)
row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
next_n_offset = (
torch.arange(batch_size * next_n, device="cuda") % next_n)
mask = positions <= (context_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
torch.arange(batch_size * next_n, device="cuda") % next_n
)
mask = positions <= (
context_lens[row_indices] - next_n + next_n_offset
).unsqueeze(1)
logits = logits.masked_fill(~mask, 0)
ref_logits = ref_logits.masked_fill(~mask, 0)

View File

@@ -7,10 +7,12 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
from vllm.vllm_flash_attn import (
fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported,
)
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
@@ -44,7 +46,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -62,10 +64,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -106,11 +111,15 @@ def test_flash_attn_with_paged_kv(
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
pytest.skip(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform.seed_everything(0)
num_seqs = len(kv_lens)
@@ -119,23 +128,19 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None
@@ -180,23 +185,27 @@ def test_flash_attn_with_paged_kv(
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window,
)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize(
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -222,11 +231,15 @@ def test_varlen_with_paged_kv(
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
pytest.skip(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
@@ -236,30 +249,23 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
out = torch.empty_like(query) if use_out else None
@@ -315,5 +321,7 @@ def test_varlen_with_paged_kv(
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)

View File

@@ -38,7 +38,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -56,10 +56,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -101,20 +104,16 @@ def test_flashinfer_decode_with_paged_kv(
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
@@ -135,9 +134,9 @@ def test_flashinfer_decode_with_paged_kv(
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=True)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD", use_tensor_cores=True
)
wrapper.plan(
kv_indptr,
kv_indices,
@@ -155,17 +154,21 @@ def test_flashinfer_decode_with_paged_kv(
output = wrapper.run(query, key_value_cache)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window,
)
(
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@@ -196,16 +199,10 @@ def test_flashinfer_prefill_with_paged_kv(
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
@@ -215,10 +212,9 @@ def test_flashinfer_prefill_with_paged_kv(
value_cache /= head_size**0.5
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
qo_indptr = [0]
kv_indptr = [0]
@@ -242,8 +238,7 @@ def test_flashinfer_prefill_with_paged_kv(
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
wrapper.plan(
qo_indptr,
kv_indptr,
@@ -264,17 +259,21 @@ def test_flashinfer_prefill_with_paged_kv(
key_value_cache,
)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window,
)
(
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
@@ -284,9 +283,13 @@ def test_flashinfer_prefill_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
pytest.skip("TODO: fix the accuracy issue")
torch.set_default_device("cuda")
current_platform.seed_everything(0)
@@ -301,17 +304,11 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
kv_cache_dtype = torch.float8_e4m3fn
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5
@@ -319,15 +316,15 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
k_scale = key_cache.amax().item() / 448.0
v_scale = value_cache.amax().item() / 448.0
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
dim=1).to(kv_cache_dtype)
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to(
kv_cache_dtype
)
assert (kv_cache_fp8.shape == key_value_cache.shape)
assert kv_cache_fp8.shape == key_value_cache.shape
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
qo_indptr = [0]
kv_indptr = [0]
@@ -351,8 +348,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
wrapper.plan(
qo_indptr,
kv_indptr,
@@ -369,19 +365,23 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
value_cache=value_cache.squeeze(1),
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache.squeeze(1),
value_cache=value_cache.squeeze(1),
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@@ -414,12 +414,9 @@ def test_flashinfer_decode_with_paged_fp8_kv(
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_value_cache = torch.randn(
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5
@@ -429,14 +426,13 @@ def test_flashinfer_decode_with_paged_fp8_kv(
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
@@ -457,32 +453,38 @@ def test_flashinfer_decode_with_paged_fp8_kv(
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2),
f"{torch.max(torch.abs(output - ref_output))}",
)

View File

@@ -13,34 +13,29 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
if not current_platform.has_device_capability(100):
pytest.skip(
reason="FlashInfer MLA Requires compute capability of 10 or above.",
allow_module_level=True)
allow_module_level=True,
)
def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]
for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]
q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)
return out
@@ -50,7 +45,7 @@ def ref_mla(
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
@pytest.mark.parametrize("block_size", [32, 64])
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
torch.set_default_device('cuda')
torch.set_default_device("cuda")
torch.manual_seed(42)
# Deepseek R1 config
@@ -59,11 +54,11 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
qk_nope_head_dim = 128
qk_rope_head_dim = 64
qk_head_dim = kv_lora_rank + qk_rope_head_dim
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
MAX_SEQ_LEN = 1024
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1,)).item() for _ in range(bs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
@@ -86,12 +81,12 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
block_id = 0
for i in range(bs):
num_blocks_needed = blocks_per_seq[i]
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
num_blocks_needed]
block_tables[i, :num_blocks_needed] = all_block_ids[
block_id : block_id + num_blocks_needed
]
block_id += num_blocks_needed
kv_cache = torch.randn(block_tables.numel(), block_size,
qk_head_dim).to(dtype)
kv_cache = torch.randn(block_tables.numel(), block_size, qk_head_dim).to(dtype)
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)
out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)

View File

@@ -6,15 +6,18 @@ import flashinfer
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
allow_module_level=True)
pytest.skip(
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
)
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
@@ -64,8 +67,9 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
@@ -106,7 +110,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens
@@ -122,10 +126,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
@@ -147,20 +150,23 @@ def test_flashinfer_trtllm_decode_with_baseline(
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout, use_tensor_cores=True)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
workspace_buffer, kv_layout, use_tensor_cores=True
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
@@ -169,17 +175,21 @@ def test_flashinfer_trtllm_decode_with_baseline(
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
o_sf_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
).to(torch.float32)
# TRTLLM Decode
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
@@ -201,13 +211,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
-1, query.shape[1] * query.shape[2] // 2
)
output_trtllm = dequantize_nvfp4_to_dtype(
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
)
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
@@ -216,8 +225,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
else:
rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - output_trtllm))}",
)
@pytest.mark.parametrize("dtype", DTYPE)
@@ -233,8 +244,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
Optional[torch.dtype]],
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_lens: tuple[int, int],
num_heads: tuple[int, int],
@@ -270,17 +282,16 @@ def test_flashinfer_trtllm_prefill_with_baseline(
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
q_lens[-1] = max_q_len
q_indptr = torch.cat([
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
])
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
query = torch.randn(torch.sum(q_lens).item(),
num_qo_heads,
head_size,
dtype=dtype)
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
@@ -288,7 +299,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
@@ -304,10 +315,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS,
(batch_size, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
@@ -329,21 +339,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout)
wrapper.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
@@ -352,17 +365,21 @@ def test_flashinfer_trtllm_prefill_with_baseline(
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
o_sf_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
).to(torch.float32)
# TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
@@ -388,13 +405,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
-1, query.shape[1] * query.shape[2] // 2
)
output_trtllm = dequantize_nvfp4_to_dtype(
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
)
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
@@ -405,5 +421,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - output_trtllm))}",
)

View File

@@ -7,30 +7,33 @@ import random
import pytest
import torch
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported,
)
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
def cal_diff(
x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False
) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
if (use_fp8):
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
if use_fp8:
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
FLASH_MLA_UNSUPPORTED_REASON = (
is_flashmla_supported()[1]
if not is_flashmla_supported()[0]
else "FlashMLA is supported"
)
@pytest.mark.skipif(not is_flashmla_supported()[0],
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@@ -41,11 +44,13 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
"torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn]
)
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, torch_dtype):
def test_flash_mla(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
@@ -57,31 +62,34 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
torch.manual_seed(0)
random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
)
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
s_q)
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
dtype=torch.int32).view(
b, max_seqlen_pad // block_size)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv,
d)[i, cache_seqlens[i].item():] = float("nan")
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)
cache_seqlens, s_q * h_q // h_kv, h_kv
)
init_dtype = q.dtype
if use_fp8:
@@ -97,16 +105,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
descale_k = None
def flash_mla():
return flash_mla_with_kvcache(q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k)
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()
@@ -119,8 +129,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k,
dtype=torch.bool).tril(diagonal=s_k - s_q)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
@@ -130,10 +139,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
blocked_k_ = (
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
if use_fp8
else blocked_k
)
blocked_v_ = (
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
if use_fp8
else blocked_v
)
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
@@ -156,8 +171,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
torch.finfo(torch_dtype).bits // 8
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
)

View File

@@ -13,6 +13,7 @@ def _cuda_sm90_available() -> bool:
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
@@ -27,18 +28,21 @@ def test_sparse_flashmla_metadata_smoke():
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
tile_md, num_splits = fm.get_mla_metadata(
cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True,
)
assert tile_md.dtype == torch.int32
assert num_splits.dtype == torch.int32
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
@@ -58,36 +62,42 @@ def test_sparse_flashmla_decode_smoke():
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
# q_heads_per_hk = num_heads_q // num_heads_k
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
tile_md, num_splits = fm.get_mla_metadata(
cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True,
)
# Inputs
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
dtype=torch.bfloat16,
device=device)
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
dtype=torch.uint8,
device=device)
indices = torch.zeros((batch_size, seqlen_q, topk),
dtype=torch.int32,
device=device)
q = torch.zeros(
(batch_size, seqlen_q, num_heads_q, head_dim_k),
dtype=torch.bfloat16,
device=device,
)
k_cache = torch.zeros(
(1, page_block_size, num_heads_k, bytes_per_token),
dtype=torch.uint8,
device=device,
)
indices = torch.zeros(
(batch_size, seqlen_q, topk), dtype=torch.int32, device=device
)
block_table = torch.zeros((batch_size, 128),
dtype=torch.int32,
device=device)
out, lse = fm.flash_mla_with_kvcache(q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_md,
num_splits,
indices=indices,
is_fp8_kvcache=True)
block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device)
out, lse = fm.flash_mla_with_kvcache(
q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_md,
num_splits,
indices=indices,
is_fp8_kvcache=True,
)
assert out.shape[0] == batch_size
assert out.shape[-1] == head_dim_v
assert lse.shape[0] == batch_size
@@ -95,6 +105,7 @@ def test_sparse_flashmla_decode_smoke():
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
@@ -112,8 +123,7 @@ def test_sparse_flashmla_prefill_smoke():
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
d_v)
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v)
assert out.shape == (s_q, h_q, d_v)
assert max_logits.shape == (s_q, h_q)
assert lse.shape == (s_q, h_q)

View File

@@ -4,8 +4,7 @@
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
from vllm.platforms import current_platform
NUM_HEADS = [4, 8]
@@ -17,8 +16,8 @@ DTYPES = [torch.float32]
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm
The difference from the main implementation is that this processes
The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
@@ -62,8 +61,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E]
return output, final_kv_cache
@@ -109,7 +107,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
out_h = torch.matmul(q_bh, kv_new)
# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
output[b, h * D : (h + 1) * D] = out_h
kv_caches[b, h] = kv_new
return output
@@ -135,12 +133,9 @@ def test_linear_decode_forward_triton(
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_caches_copy = kv_caches.clone()
@@ -150,15 +145,14 @@ def test_linear_decode_forward_triton(
slot_idx = torch.arange(batch_size, device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
reference_output = reference_linear_decode(
q, k, v, kv_caches_copy, slope_rate, slot_idx
)
torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
assert triton_output.shape == (batch_size, num_heads * head_size)
@@ -184,12 +178,9 @@ def test_linear_decode_forward_triton_with_padding(
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_caches_copy = kv_caches.clone()
@@ -199,14 +190,15 @@ def test_linear_decode_forward_triton_with_padding(
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
reference_output = reference_linear_decode(
q, k, v, kv_caches_copy, slope_rate, slot_idx
)
padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]
@@ -217,15 +209,11 @@ def test_linear_decode_forward_triton_with_padding(
for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)
torch.testing.assert_close(
kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol
)
torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)
torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)
assert triton_output.shape == (batch_size, num_heads * head_size)
@@ -249,39 +237,33 @@ def test_lightning_attention_reference(
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_history = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
)
kv_history_clone = kv_history.clone()
ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)
q, k, v, ed, 256, kv_history
)
from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)
q, k, v, ed, 256, kv_history_clone
)
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)
torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape

View File

@@ -7,19 +7,20 @@ import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton)
merge_attn_states as merge_attn_states_triton,
)
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states_torch(
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
):
p_lse = prefix_lse
s_lse = suffix_lse
@@ -32,15 +33,13 @@ def merge_attn_states_torch(
s_lse = s_lse - max_lse
p_lse_exp = torch.exp(p_lse)
s_lse_exp = torch.exp(s_lse)
out_se = (p_lse_exp + s_lse_exp)
out_se = p_lse_exp + s_lse_exp
if output_lse is not None:
output_lse = torch.log(out_se) + max_lse
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0,
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output = prefix_output * p_scale + suffix_output * s_scale
return output, output_lse
@@ -55,8 +54,10 @@ all_case_info: list[tuple] = []
def generate_markdown_table():
global all_case_info
table_header = ("| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |")
table_header = (
"| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |"
)
table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
def shortly_dtype(dtype: torch.dtype) -> str:
@@ -68,16 +69,26 @@ def generate_markdown_table():
print(table_header)
print(table_separator)
for info in all_case_info:
(num_tokens, num_heads, head_size, dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved) = info
(
num_tokens,
num_heads,
head_size,
dtype,
device,
avg_time_torch_kernel,
avg_time_triton_kernel,
avg_time_cuda_kernel,
performance_improved,
) = info
dtype = shortly_dtype(dtype)
device = shortly_device(device)
print(f"| {num_tokens} | {num_heads} | {head_size} "
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
f"| {avg_time_triton_kernel:.5f}ms "
f"| {avg_time_cuda_kernel:.5f}ms "
f"| {performance_improved:.4f}x |")
print(
f"| {num_tokens} | {num_heads} | {head_size} "
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
f"| {avg_time_triton_kernel:.5f}ms "
f"| {avg_time_cuda_kernel:.5f}ms "
f"| {performance_improved:.4f}x |"
)
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@@ -85,29 +96,28 @@ def generate_markdown_table():
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype):
def test_merge_attn_states(
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
if not current_platform.is_cuda():
pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel')
pytest.skip(
"Currently only support compare triton merge_attn_states "
"with custom cuda merge_attn_states kernel"
)
NUM_TOKENS = num_tokens
NUM_HEADS = num_query_heads
HEAD_SIZE = head_size
print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"Device: {current_platform.get_device_name()}")
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"Device: {current_platform.get_device_name()}"
)
# prefix_lse and suffix_lse contain inf and normal values
prefix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
suffix_lse = torch.randn(NUM_HEADS,
NUM_TOKENS,
dtype=torch.float32,
device="cuda")
prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
# Generate boolean masks
mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
@@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
prefix_lse[mask_prefix] = float('inf')
suffix_lse[mask_suffix] = float('inf')
prefix_lse[mask_prefix] = float("inf")
suffix_lse[mask_suffix] = float("inf")
# Other input tensors (need to be initialized but
# no actual calculation needed)
output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
dtype=torch.float32,
device="cuda")
prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
dtype=output_dtype,
device="cuda")
output = torch.zeros(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
output_lse = torch.zeros(
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
)
prefix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
suffix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
)
warmup_times = 2
repeat_times = 20
@@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
suffix_lse_torch = suffix_lse.clone()
for _ in range(warmup_times):
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
output_torch,
prefix_output,
prefix_lse_torch,
suffix_output,
suffix_lse_torch,
output_lse_torch,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
output_torch, output_lse_torch = merge_attn_states_torch(
output_torch, prefix_output, prefix_lse_torch, suffix_output,
suffix_lse_torch, output_lse_torch)
output_torch,
prefix_output,
prefix_lse_torch,
suffix_output,
suffix_lse_torch,
output_lse_torch,
)
end.record()
torch.cuda.synchronize()
total_time_torch_kernel += start.elapsed_time(end)
@@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
end = torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times):
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
merge_attn_states_triton(
output_ref_triton,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_ref_triton,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
suffix_output, suffix_lse,
output_lse_ref_triton)
merge_attn_states_triton(
output_ref_triton,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_ref_triton,
)
end.record()
torch.cuda.synchronize()
total_time_triton_kernel += start.elapsed_time(end)
@@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
output_lse_cuda = output_lse.clone()
for _ in range(warmup_times):
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
merge_attn_states_cuda(
output_cuda,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_cuda,
)
torch.cuda.synchronize()
for _ in range(repeat_times):
start.record()
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
suffix_output, suffix_lse, output_lse_cuda)
merge_attn_states_cuda(
output_cuda,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
output_lse_cuda,
)
end.record()
torch.cuda.synchronize()
total_time_cuda_kernel += start.elapsed_time(end)
@@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
f"Performance: {performance_improved:.5f}x")
print(
f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
f"Performance: {performance_improved:.5f}x"
)
print("-" * 100)
# 4. Correctness compare
@@ -232,35 +276,45 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
# states operation.
output_ref = output_ref_triton
output_lse_ref = output_lse_ref_triton
torch.testing.assert_close(output_cuda.float(),
output_ref.float(),
atol=1e-3,
rtol=rtol)
torch.testing.assert_close(
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
)
print("Output all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
print("-" * 100)
torch.testing.assert_close(output_lse_cuda.float(),
output_lse_ref.float(),
atol=1e-3,
rtol=rtol)
torch.testing.assert_close(
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
)
print("Output LSE all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
print("-" * 100)
print("All output values test passed! All inf values "
"are correctly replaced with -inf.")
print(
"All output values test passed! All inf values "
"are correctly replaced with -inf."
)
print("-" * 100)
device = current_platform.get_device_name()
all_case_info.append(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device,
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
performance_improved))
if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
len(NUM_QUERY_HEADS) * len(DTYPES)):
(
NUM_TOKENS,
NUM_HEADS,
HEAD_SIZE,
output_dtype,
device,
avg_time_torch_kernel,
avg_time_triton_kernel,
avg_time_cuda_kernel,
performance_improved,
)
)
if len(all_case_info) == (
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
):
generate_markdown_table()

View File

@@ -5,6 +5,7 @@ Test:
* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch
import pytest
@@ -21,11 +22,11 @@ from vllm.platforms.rocm import RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
# Clear xformers availability cache
import vllm.attention.layer as layer_module
layer_module.USE_XFORMERS_OPS = None
@@ -37,49 +38,63 @@ def test_mha_attn_platform(device: str):
torch.set_default_dtype(torch.float16)
if device == "cpu":
with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CpuPlatform()):
with (
patch("vllm.attention.layer.current_platform", CpuPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
elif device == "hip":
with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()):
with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()):
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
# - should use xformers
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()), \
patch("vllm.attention.layer.check_upstream_fa_availability",
return_value=False):
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch(
"vllm.attention.layer.check_upstream_fa_availability",
return_value=False,
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
# - should use upstream FA
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()), \
patch("vllm.attention.layer.check_upstream_fa_availability",
return_value=True), \
patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
{
'flash_attn_varlen_func': lambda *args, **kwargs: None
})()}):
with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch(
"vllm.attention.layer.check_upstream_fa_availability", return_value=True
),
patch.dict(
"sys.modules",
{
"flash_attn": type(
"MockFlashAttn",
(),
{"flash_attn_varlen_func": lambda *args, **kwargs: None},
)()
},
),
):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN
@@ -108,9 +123,11 @@ NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
DTYPES = (
[torch.half, torch.bfloat16, torch.float]
if not current_platform.is_rocm()
else [torch.half, torch.bfloat16]
)
CUDA_DEVICES = ["cuda"]
@@ -138,10 +155,9 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads)
attn = MultiHeadAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0

View File

@@ -11,30 +11,24 @@ from vllm.utils import cdiv
def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]
for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]
q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)
return out
@@ -63,18 +57,17 @@ def test_mla_decode_cpu(
torch.set_default_dtype(dtype)
torch.manual_seed(0)
scale = d**(-0.5)
scale = d ** (-0.5)
if varlen:
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
seq_lens = seq_lens.clip(2).to(torch.int32)
else:
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
max_seq_len = seq_lens.max().item()
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
q = torch.randn(bs, h_q, d)
block_table = torch.arange(bs * seqlen_pad // block_size,
dtype=torch.int32)
block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32)
block_table = block_table.view(bs, seqlen_pad // block_size)
kv_cache = torch.randn(block_table.numel(), block_size, d)
@@ -82,8 +75,7 @@ def test_mla_decode_cpu(
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
out_mla = q.new_zeros(bs, h_q, dv)
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
seq_lens)
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)
out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

View File

@@ -39,7 +39,7 @@ def test_pack_seq_basic_fp8():
start_idx = sum(lengths_list[:b])
seq_len = lengths_list[b]
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
expected_data = x[start_idx : start_idx + seq_len].to(torch.float32)
actual_data = packed[b, :seq_len].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
@@ -62,7 +62,7 @@ def test_pack_seq_custom_padding_fp8():
# Check valid data
for b in range(B):
start_idx = b * 10
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
expected_data = x[start_idx : start_idx + 10].to(torch.float32)
actual_data = result[b, :10].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
@@ -73,9 +73,7 @@ def test_pack_seq_custom_padding_fp8():
elif pad_value > 0:
assert torch.all(padded_data > 50) # Large positive values
else:
assert torch.allclose(padded_data,
torch.zeros_like(padded_data),
atol=1e-2)
assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2)
def test_pack_seq_default_negative_inf_padding_fp8():
@@ -93,7 +91,8 @@ def test_pack_seq_default_negative_inf_padding_fp8():
# Check that padding is large negative values (fp8 representation of -inf)
padded_data = result[:, 10:].to(torch.float32)
assert torch.all(
padded_data < -100) # fp8 -inf is represented as large negative number
padded_data < -100
) # fp8 -inf is represented as large negative number
def test_pack_seq_edge_cases_fp8():
@@ -142,7 +141,7 @@ def test_pack_seq_different_block_sizes_fp8():
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = b * 25
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
expected_data = x[start_idx : start_idx + 25].to(torch.float32)
actual_data = result[b, :25].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
@@ -198,10 +197,7 @@ def test_pack_unpack_roundtrip_fp8():
# Unpack without explicit start locations (computed in kernel)
unpacked_with_loc = unpack_seq_triton(packed, lengths)
assert_close(x_f32,
unpacked_with_loc.to(torch.float32),
rtol=1e-3,
atol=1e-2)
assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2)
def test_unpack_seq_triton_edge_cases_fp8():
@@ -216,10 +212,7 @@ def test_unpack_seq_triton_edge_cases_fp8():
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
@@ -228,10 +221,9 @@ def test_unpack_seq_triton_edge_cases_fp8():
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
# Only compare the first 3 elements that were actually packed
assert_close(x[:3].to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
assert_close(
x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2
)
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
@@ -239,7 +231,4 @@ def test_unpack_seq_triton_edge_cases_fp8():
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)

View File

@@ -12,8 +12,7 @@ from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ -22,9 +21,7 @@ NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
@@ -50,12 +47,10 @@ def test_contexted_kv_attention(
device: str,
op: Callable,
) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
current_platform.seed_everything(0)
torch.set_default_device(device)
@@ -93,38 +88,29 @@ def test_contexted_kv_attention(
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
k_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
v_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long),
dim=0)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
@@ -135,61 +121,71 @@ def test_contexted_kv_attention(
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc]
)
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc]
)
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
k_cache = (
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
.permute(0, 2, 3, 1, 4)
.contiguous()
)
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
v_cache = (
v_cache.view(-1, block_size, num_kv_heads, head_size)
.permute(0, 2, 3, 1)
.contiguous()
)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
start_time = time.time()
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
@@ -201,22 +197,24 @@ def test_contexted_kv_attention(
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.view(
query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
)
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
query_lens, seq_lens
)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
@@ -239,7 +237,7 @@ def test_contexted_kv_attention(
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@@ -262,12 +260,10 @@ def test_contexted_kv_attention_alibi(
device: str,
op: Callable,
) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
current_platform.seed_everything(0)
torch.set_default_device(device)
@@ -280,9 +276,9 @@ def test_contexted_kv_attention_alibi(
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
@@ -290,17 +286,16 @@ def test_contexted_kv_attention_alibi(
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
num_remaining_heads = min(
closest_power_of_2, total_num_heads - closest_power_of_2
)
extra_powers = torch.arange(
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
@@ -328,38 +323,29 @@ def test_contexted_kv_attention_alibi(
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
k_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
v_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long),
dim=0)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
@@ -370,82 +356,90 @@ def test_contexted_kv_attention_alibi(
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc]
)
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc]
)
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
k_cache = (
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
.permute(0, 2, 3, 1, 4)
.contiguous()
)
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
v_cache = (
v_cache.view(-1, block_size, num_kv_heads, head_size)
.permute(0, 2, 3, 1)
.contiguous()
)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
start_time = time.time()
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
query_pad[seq_start:seq_end, ...] = torch.cat(
[
torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...],
],
dim=0,
)
seq_start += seq_len
query_start += query_len
query = query_pad
@@ -456,11 +450,12 @@ def test_contexted_kv_attention_alibi(
# heads.
#
# see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
@@ -483,24 +478,23 @@ def test_contexted_kv_attention_alibi(
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = xops.memory_efficient_attention_forward(
query[:, seq_start:seq_end],
key[:, seq_start:seq_end],
value[:, seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale,
)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_len, num_heads, head_size
)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@@ -532,9 +526,16 @@ def test_contexted_kv_attention_f32(
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
sliding_window, dtype, kv_cache_dtype, device,
op)
test_contexted_kv_attention(
num_heads,
num_queries_per_kv,
head_size,
sliding_window,
dtype,
kv_cache_dtype,
device,
op,
)
@pytest.mark.optional
@@ -555,5 +556,6 @@ def test_contexted_kv_attention_alibi_f32(
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
dtype, kv_cache_dtype, device, op)
test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
)

View File

@@ -11,8 +11,7 @@ from vllm.utils import STR_BACKEND_ENV_VAR
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
"""Clear lru cache to ensure each test case runs without caching."""
_cached_get_attn_backend.cache_clear()
@@ -22,46 +21,29 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
# Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform",
RocmPlatform())
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
# Test standard ROCm attention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "TRITON_ATTN")
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
# MLA test for deepseek related
# change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
16,
False,
use_mla=True)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA"
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
16,
False,
use_mla=True)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
1,
False,
use_mla=True)
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA"
# If attention backend is None
@@ -70,10 +52,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# The selected backend is ROCM_AITER_MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
1,
False,
use_mla=True)
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
assert backend.get_name() == "ROCM_AITER_MLA"

View File

@@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),
device="cuda")
req_to_page = torch.randint(
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
)
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
@@ -48,7 +46,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),

View File

@@ -14,9 +14,11 @@ HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
]
QDTYPES = (
[None, torch.float8_e4m3fn]
if not current_platform.is_rocm()
else [None, torch.float8_e4m3fnuz]
)
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
@@ -42,7 +44,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -60,10 +62,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None and soft_cap > 0:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -77,9 +82,9 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize(
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -111,30 +116,23 @@ def test_triton_unified_attn(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
output = torch.empty_like(query)
@@ -188,5 +186,7 @@ def test_triton_unified_attn(
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)

View File

@@ -8,19 +8,23 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, MulAndSilu,
NewGELU, QuickGELU,
SiluAndMul, SwigluOAIAndMul)
from vllm.model_executor.layers.activation import (
FastGELU,
FatreluAndMul,
GeluAndMul,
MulAndSilu,
NewGELU,
QuickGELU,
SiluAndMul,
SwigluOAIAndMul,
)
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize(
@@ -73,24 +77,19 @@ def test_act_and_mul(
out = layer(x)
ref_out = layer.forward_native(x)
if activation == "swigluoai_and_mul":
rtol = {
#For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16:
2e-3,
torch.bfloat16:
2e-2,
torch.float:
1.3e-6
# For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16: 2e-3,
torch.bfloat16: 2e-2,
torch.float: 1.3e-6,
}
def _get_rtol(output) -> float:
return rtol[output.dtype]
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=_get_rtol(out))
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=_get_rtol(out)
)
else:
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
@@ -98,7 +97,7 @@ def test_act_and_mul(
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "fatrelu":
opcheck(fn, (out, x, threshold))
@@ -108,9 +107,14 @@ def test_act_and_mul(
opcheck(fn, (out, x))
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize(
"activation",
[
(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick),
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -132,10 +136,9 @@ def test_activation(
fn = activation[1]
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out)
)
out = torch.empty_like(x)
opcheck(fn, (out, x))

View File

@@ -24,9 +24,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
EPS = 1e-6
@@ -34,13 +32,12 @@ EPS = 1e-6
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor]) \
-> tuple[torch.Tensor, Optional[torch.Tensor]]:
def ref_rms_norm(
rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor]
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual)
@@ -50,12 +47,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm,
return out, residual
def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
def ref_dynamic_per_token_quant(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
@@ -64,9 +62,9 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
# Quant
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(torch_out,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
else:
assert quant_dtype == torch.int8
torch_out, scales = ops.scaled_int8_quant(torch_out)
@@ -74,38 +72,41 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
return torch_out, scales, residual
def ref_impl(rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub)
def ref_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub
)
def ops_dynamic_per_token_quant(weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
def ops_dynamic_per_token_quant(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
quant_dtype, scale_ub,
residual)
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
return out, scales, residual
def ops_impl(weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub)
def ops_impl(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@@ -146,12 +147,14 @@ def test_rms_norm(
residual = torch.randn_like(x) * scale if add_residual else None
if scale_ub is not None:
rms_x, _ = ref_rms_norm(layer, x, residual)
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
ref_out, ref_scales, ref_residual = \
ref_impl(layer, x, quant_dtype, residual, scale_ub)
ops_out, ops_scales, ops_residual = \
ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
ref_out, ref_scales, ref_residual = ref_impl(
layer, x, quant_dtype, residual, scale_ub
)
ops_out, ops_scales, ops_residual = ops_impl(
layer.weight, x, quant_dtype, residual, scale_ub
)
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
@@ -160,15 +163,18 @@ def test_rms_norm(
# big atol to account for round-off errors.
assert torch.allclose(ref_out, ops_out, atol=1)
else:
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
assert torch.allclose(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
if add_residual:
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty_like(x, dtype=quant_dtype)
scales = torch.empty((x.numel() // x.shape[-1], 1),
device=x.device,
dtype=torch.float32)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual))
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)

View File

@@ -11,13 +11,22 @@ from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
8199] # Arbitrary values for testing
HIDDEN_SIZES = [
8,
768,
769,
770,
771,
5120,
5124,
5125,
5126,
8192,
8199,
] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -63,11 +72,14 @@ def test_rms_norm(
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
if residual is not None:
opcheck(torch.ops._C.fused_add_rms_norm,
(x, residual, layer.weight.data, layer.variance_epsilon))
opcheck(
torch.ops._C.fused_add_rms_norm,
(x, residual, layer.weight.data, layer.variance_epsilon),
)
else:
opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon))
opcheck(
torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon)
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -98,7 +110,8 @@ def test_poly_norm(
opcheck(
torch.ops._C.poly_norm,
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon))
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -144,7 +157,8 @@ def test_fused_rms_norm_quant(
if add_residual:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6
)
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
@@ -152,29 +166,32 @@ def test_fused_rms_norm_quant(
x_unfused = x_unfused_base[..., :hidden_size]
assert x_unfused.is_contiguous() != strided_input
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
quant_scale_t)
torch.ops._C.static_scaled_fp8_quant(
out_quant, x_unfused.contiguous(), quant_scale_t
)
torch.cuda.synchronize()
torch.testing.assert_close(residual_fused,
residual,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6),
)
else:
torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight,
quant_scale_t, 1e-6)
torch.ops._C.rms_norm_static_fp8_quant(
out_quant_fused, x, weight, quant_scale_t, 1e-6
)
torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm,
quant_scale_t)
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t)
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
opcheck(
torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6),
)
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
out_quant_fused.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(
out_quant.to(dtype=torch.float32),
out_quant_fused.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3,
)

View File

@@ -14,25 +14,25 @@ from vllm.platforms import current_platform
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size: int, max_position_embeddings: int,
dtype: torch.dtype, device: torch.device):
def generate_test_data(
num_tokens: int,
num_q_heads: int,
num_kv_heads: int,
head_size: int,
max_position_embeddings: int,
dtype: torch.dtype,
device: torch.device,
):
"""Generate test data for given configuration."""
current_platform.seed_everything(42)
# Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(0,
max_position_embeddings // 4, (3, num_tokens),
device=device)
positions = torch.randint(
0, max_position_embeddings // 4, (3, num_tokens), device=device
)
# Create query and key tensors
query = torch.randn(num_tokens,
num_q_heads * head_size,
dtype=dtype,
device=device)
key = torch.randn(num_tokens,
num_kv_heads * head_size,
dtype=dtype,
device=device)
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
return positions, query, key
@@ -59,7 +59,8 @@ MODELS_TO_TEST = [
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
],
),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
@@ -67,24 +68,33 @@ MODELS_TO_TEST = [
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
],
),
]
num_tokens_list = [11, 8192]
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize(
"model_info, model_name",
[
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
def test_mrope(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
):
atol = model_info.atol
rtol = model_info.rtol
@@ -96,8 +106,11 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
head_dim = (
config.head_dim
if hasattr(config, "head_dim")
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
@@ -117,9 +130,9 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
# create q k v input tensors
# create rotary pos emb input tensors
positions, query, key = generate_test_data(num_tokens, num_heads,
num_kv_heads, head_dim,
max_position, dtype, device)
positions, query, key = generate_test_data(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)
query_native, key_native = mrope_helper_class.forward_native(
positions,
@@ -137,19 +150,26 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize(
"model_info, model_name",
[
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
],
)
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing(model_name: str,
model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
def test_mrope_torch_compile_tracing(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
):
atol = model_info.atol
rtol = model_info.rtol
@@ -161,8 +181,11 @@ def test_mrope_torch_compile_tracing(model_name: str,
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
head_dim = (
config.head_dim
if hasattr(config, "head_dim")
else config.hidden_size // total_num_heads
)
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
@@ -180,16 +203,16 @@ def test_mrope_torch_compile_tracing(model_name: str,
).to(device=device)
# Generate test data
positions, query, key = generate_test_data(num_tokens, num_heads,
num_kv_heads, head_dim,
max_position, dtype, device)
positions, query, key = generate_test_data(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)
# Create a wrapper that makes the in-place function appear functional
def functional_forward_cuda(pos, q, k):
"""Wrapper that converts in-place operation to functional style
CUDA Graph does not support in-place operations.
This wrapper creates working copies of the
This wrapper creates working copies of the
input tensors and modifies them.
"""
q_work = q.clone() # Create working copies
@@ -206,11 +229,13 @@ def test_mrope_torch_compile_tracing(model_name: str,
)
try:
compiled_forward_cuda = torch.compile(functional_forward_cuda,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_forward_cuda = torch.compile(
functional_forward_cuda,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
# Run compiled version
query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
@@ -225,25 +250,16 @@ def test_mrope_torch_compile_tracing(model_name: str,
mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
# Verify results
torch.testing.assert_close(query_compiled_cuda,
query_cuda,
atol=atol,
rtol=rtol)
torch.testing.assert_close(key_compiled_cuda,
key_cuda,
atol=atol,
rtol=rtol)
torch.testing.assert_close(query_compiled_cuda,
query_native,
atol=atol,
rtol=rtol)
torch.testing.assert_close(key_compiled_cuda,
key_native,
atol=atol,
rtol=rtol)
torch.testing.assert_close(
query_compiled_cuda, query_cuda, atol=atol, rtol=rtol
)
torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol)
torch.testing.assert_close(
query_compiled_cuda, query_native, atol=atol, rtol=rtol
)
torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol)
print("✓ forward_cuda successfully traced with torch.compile inductor")
except Exception as e:
pytest.fail(
f"forward_cuda failed to trace with torch.compile inductor: {e}")
pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}")

View File

@@ -8,11 +8,11 @@ from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
x = torch.randn(shape, dtype=dtype).cuda()
perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
opcheck(torch.ops._C.permute_cols, (x, perm))
y = permute_cols(x, perm)
torch.testing.assert_close(y, x[:, perm])
torch.testing.assert_close(y, x[:, perm])

View File

@@ -19,30 +19,33 @@ NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
USE_KEY = [True, False]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
def _get_flat_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads * head_size)
# For testing sliced tensors
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
def _get_padded_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size + 64)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
def _get_batch_tensor_shape(
batch_size: int, seq_len: int, num_heads: int, head_size: int
) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
_get_batch_tensor_shape,
_get_flat_tensor_shape,
_get_padded_tensor_shape,
]
@@ -97,41 +100,63 @@ def test_rotary_embedding(
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(
out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query),
)
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
torch.testing.assert_close(
out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key),
)
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
assert ref_key is None and out_key is None, "expected returned key to be None"
@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = (None, {
"rope_type": "linear",
"factor": (1, )
}, {
"rope_type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
ROPE_SCALINGS = (
None,
{"rope_type": "linear", "factor": (1,)},
{"rope_type": "dynamic", "factor": 1},
)
settings = (
HEAD_SIZES,
ROTARY_DIMS,
MAX_POSITIONS,
BASES,
IS_NEOX_STYLE,
ROPE_SCALINGS,
DTYPES,
)
rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
(
head_size,
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_stype, rope_scaling, dtype)
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
dtype,
)
# different settings cannot share the same rope module
assert id(rope) not in rope_setting_id_map.values()
assert all(x.dtype == dtype for x in rope.buffers())
@@ -139,11 +164,25 @@ def test_rope_module_cache():
rope_setting_id_map[str(setting)] = id(rope)
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
(
head_size,
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
dtype,
) = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_stype, rope_scaling, dtype)
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_stype,
rope_scaling,
dtype,
)
# check if cache take effect
assert id(rope) == rope_setting_id_map[str(setting)]

View File

@@ -13,17 +13,20 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None):
def rotary_embedding_opcheck(
rot,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding() is a in-place operation
# that updates the query and key tensors.
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))
opcheck(
torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache, rot.is_neox_style),
)
@pytest.mark.parametrize("device", ["cuda"])
@@ -34,26 +37,30 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len, use_key, head_stride_is_contiguous):
def test_rotary_embedding_opcheck(
dist_init,
device,
max_position,
is_neox_style,
rotary_dim,
head_size,
seq_len,
use_key,
head_stride_is_contiguous,
):
batch_size = 1
base = 10000
num_heads = 7
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
rot = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, torch.float32
)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
head_stride = head_size + (64 if head_stride_is_contiguous else 0)
query = torch.randn(batch_size,
seq_len,
num_heads,
head_stride,
dtype=torch.float32,
device=device)
query = torch.randn(
batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device
)
key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
@@ -64,5 +71,8 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
# [..., num_heads * head_dim] shape/layout
if head_stride_is_contiguous:
rotary_embedding_opcheck(
rot, positions, query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None)
rot,
positions,
query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None,
)

View File

@@ -5,20 +5,14 @@ import torch
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
@@ -40,11 +34,7 @@ def test_cpu_write(device):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
@@ -59,4 +49,4 @@ def test_gpu_write(device):
assert cpu_tensor[0, 0] == 2
assert cpu_tensor[2, 3] == 4
assert cpu_tensor[4, 5] == -2
assert cpu_tensor[4, 5] == -2

View File

@@ -10,7 +10,9 @@ from einops import rearrange
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.platforms import current_platform
@@ -39,18 +41,15 @@ def causal_conv1d_ref(
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
dtype_in
) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
@@ -59,12 +58,9 @@ def causal_conv1d_ref(
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
def causal_conv1d_update_ref(
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -91,24 +87,25 @@ def causal_conv1d_update_ref(x,
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
weight.dtype
) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1)
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
-(width - 1), 0, dtype=torch.long, device=x.device
).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = (
torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
)
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
0
) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
:, :, -seqlen:
]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@@ -117,15 +114,17 @@ def causal_conv1d_update_ref(x,
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
@@ -150,8 +149,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
itype):
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
@@ -167,23 +165,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation)
out_ref = causal_conv1d_update_ref(x_ref,
conv_state_ref,
weight,
bias,
activation=activation)
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@@ -192,9 +183,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
width, seqlen, has_bias,
silu_activation, itype):
def test_causal_conv1d_update_with_batch_gather(
batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
@@ -209,31 +200,30 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
total_entries = 10 * batch_size
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
dtype=itype).transpose(1, 2)
x = torch.randn(
padded_batch_size, seqlen, dim, device=device, dtype=itype
).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
conv_state = torch.randn(
total_entries, width - 1, dim, device=device, dtype=itype
).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
@@ -242,22 +232,23 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.equal(
conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@@ -265,12 +256,13 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
@pytest.mark.parametrize('with_padding', [True, False])
@pytest.mark.parametrize('batch', [4, 10])
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
@@ -288,19 +280,19 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
).tolist()
)
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s")[:, 4096:4096 + dim, :]
"b s d -> b d s",
)[:, 4096 : 4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
@@ -309,34 +301,34 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(total_entries,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
final_states = torch.randn(
total_entries, width - 1, dim, device=x.device, dtype=x.dtype
).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=x.device)[:batch_size]
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)
has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
:batch_size
]
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1,
)
out = causal_conv1d_fn(
x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = []
out_ref_b = []
@@ -353,16 +345,20 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[
padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].
unsqueeze(0) if has_initial_states[i] else None))
final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_states[i]
else None,
)
)
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)
assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(
final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol,
)
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)

View File

@@ -7,8 +7,10 @@ import pytest
import torch
from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@@ -24,14 +26,15 @@ from vllm.utils import update_environment_variables
(64, 2),
(64, 4), # hidden_size be divisible by num_gpus
(100, 5), # and n_groups must divide hidden_size
])
],
)
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype,
device: str = 'cuda',
device: str = "cuda",
):
hidden_size, n_groups = hidden_size_n_groups
num_processes = 2
@@ -39,17 +42,19 @@ def test_mixer2_gated_norm_multi_gpu(
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(
num_processes,
batch_size,
seq_len,
hidden_size,
n_groups,
dtype,
device,
),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
batch_size,
seq_len,
hidden_size,
n_groups,
dtype,
device,
),
nprocs=nprocs,
)
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)
@@ -71,20 +76,22 @@ def mixer2_gated_norm_tensor_parallel(
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights an inputs
weight = torch.rand((hidden_size, ), dtype=dtype, device=device)
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
gate_states = torch.randn(batch_size, seq_len, hidden_size)
@@ -97,14 +104,18 @@ def mixer2_gated_norm_tensor_parallel(
# create gated-norm without TP to compute reference
# - utilize mock patching to disable TP when
with (unittest.mock.patch(
with (
unittest.mock.patch(
"vllm.model_executor.layers.mamba.mamba_mixer2."
"get_tensor_model_parallel_world_size",
return_value=1),
unittest.mock.patch(
"vllm.model_executor.layers.mamba.mamba_mixer2."
"get_tensor_model_parallel_rank",
return_value=0)):
return_value=1,
),
unittest.mock.patch(
"vllm.model_executor.layers.mamba.mamba_mixer2."
"get_tensor_model_parallel_rank",
return_value=0,
),
):
mixer_single_gpu = Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
@@ -115,12 +126,13 @@ def mixer2_gated_norm_tensor_parallel(
# generate and compare
N = hidden_size // world_size
output = mixer(
hidden_states[..., local_rank * N:(local_rank + 1) * N],
gate_states[..., local_rank * N:(local_rank + 1) * N],
hidden_states[..., local_rank * N : (local_rank + 1) * N],
gate_states[..., local_rank * N : (local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.testing.assert_close(output,
ref_output[...,
local_rank * N:(local_rank + 1) * N],
atol=5e-3,
rtol=1e-3)
torch.testing.assert_close(
output,
ref_output[..., local_rank * N : (local_rank + 1) * N],
atol=5e-3,
rtol=1e-3,
)

View File

@@ -10,20 +10,15 @@ from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
selective_scan_fn,
selective_state_update,
)
from vllm.platforms import current_platform
def selective_state_update_ref(state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False):
def selective_state_update_ref(
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -73,16 +68,17 @@ def selective_state_update_ref(state,
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
A) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n",
h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n",
h=nheads // ngroups) # (batch, nheads, dstate)
dA = torch.exp(
rearrange(dt, "b h d -> b h d 1") * A
) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
state.copy_(state * dA +
dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
B, "b h n -> b h 1 n"
) # (batch, nheads, dim, dstate)
state.copy_(
state * dA + dB * rearrange(x, "b h d -> b h d 1")
) # (batch, dim, dstate
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
@@ -92,18 +88,20 @@ def selective_state_update_ref(state,
return out
def selective_scan_ref(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
prev_state=None,
final_state_out=None):
def selective_scan_ref(
u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
prev_state=None,
final_state_out=None,
):
"""
u: r(B D L)
delta: r(B D L)
@@ -132,26 +130,26 @@ def selective_scan_ref(u,
C = C.float()
x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
ys = []
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
y = torch.einsum("bdn,dn->bd", x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
if i == u.shape[2] - 1:
if final_state_out is None:
final_state_out = x
@@ -166,20 +164,22 @@ def selective_scan_ref(u,
return out if not return_last_state else (out, final_state_out)
def selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID):
def selective_scan_opcheck_fn(
u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
@@ -206,30 +206,55 @@ def selective_scan_opcheck_fn(u,
# Disable test_autograd_registration for now as it seems to trigger
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
cache_indices, has_initial_state, ssm_states, pad_slot_id),
test_utils=["test_schema", "test_faketensor"])
opcheck(
torch.ops._C.selective_scan_fwd,
(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
cu_seq_len,
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
),
test_utils=["test_schema", "test_faketensor"],
)
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype',
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
has_z, has_delta_bias, delta_softplus, seqlen, itype,
wtype, scan_chunks):
def test_selective_scan(
is_variable_B,
is_variable_C,
varBC_groups,
has_D,
has_z,
has_delta_bias,
delta_softplus,
seqlen,
itype,
wtype,
scan_chunks,
):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
device = "cuda"
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
@@ -242,7 +267,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
batch_size = 1
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
A_ref = A.clone()
if not is_variable_B:
B_shape = [dim, dstate]
@@ -250,9 +275,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
B_shape = [batch_size, dstate, seqlen]
else:
B_shape = [batch_size, varBC_groups, dstate, seqlen]
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
B_ref = B.clone()
if not is_variable_C:
C_shape = [dim, dstate]
@@ -260,27 +283,27 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
C_shape = [batch_size, dstate, seqlen]
else:
C_shape = [batch_size, varBC_groups, dstate, seqlen]
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
z = torch.randn(batch_size, dim, seqlen, device=device,
dtype=itype) if has_z else None
z = (
torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
if has_z
else None
)
z_ref = z.clone() if has_z else None
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
delta_bias = (
(0.5 * torch.rand(dim, device=device, dtype=torch.float32))
if has_delta_bias
else None
)
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
u_ref = u.clone()
delta = (0.5 *
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)
delta_ref = delta.clone()
state_shape = (batch_size, u.shape[1], int(A.shape[1]))
state = torch.randn(state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False)
state_ref = state.clone()
out = None
out_ref = None
@@ -312,9 +335,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
has_initial_state=torch.ones(batch_size,
device=u.device,
dtype=torch.bool) if c > 0 else None)
has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
if c > 0
else None,
)
outs.append(out)
if len(outs) > 1:
out = torch.cat(outs, dim=-1)
@@ -329,27 +353,29 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
z=z_ref,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=True)
return_last_state=True,
)
assert out is not None and out_ref is not None
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u,
delta,
A,
B,
C,
D,
z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
ssm_states=state)
selective_scan_opcheck_fn(
u,
delta,
A,
B,
C,
D,
z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
ssm_states=state,
)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@@ -374,52 +400,47 @@ def test_selective_state_update(dim, dstate, has_z, itype):
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
selective_state_update(
state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
varBC_groups, has_D, has_z, has_delta_bias,
delta_softplus, return_last_state, seqlen,
itype, wtype):
def test_selective_scan_varlen(
with_padding,
is_variable_B,
is_variable_C,
varBC_groups,
has_D,
has_z,
has_delta_bias,
delta_softplus,
return_last_state,
seqlen,
itype,
wtype,
):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
device = "cuda"
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 3e-2, 5e-2
@@ -443,72 +464,79 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
).tolist()
)
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda()
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
A_ref = A.clone()
B_shape = [varBC_groups, dstate, seqlen]
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
B_ref = B.clone()
C_shape = [varBC_groups, dstate, seqlen]
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
z = torch.randn(dim, seqlen, device=device, dtype=itype)
z_ref = z.clone()
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
delta_bias = (
(0.5 * torch.rand(dim, device=device, dtype=torch.float32))
if has_delta_bias
else None
)
u = torch.randn(dim, seqlen, device=device, dtype=itype)
u_ref = u.clone()
delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)
delta_ref = delta.clone()
out = None
out_ref = None
prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
prev_state = torch.randn(prev_state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
prev_state = torch.randn(
prev_state_shape, device=u.device, dtype=itype, requires_grad=False
)
prev_state_ref = prev_state.clone()
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=u.device)[:batch_size]
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
:batch_size
]
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1,
)
has_initial_state = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=u.device)
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, padded_state_indices,
has_initial_state)
has_initial_state = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device
)
out = selective_scan_fn(
u,
prev_state,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
cumsum,
padded_state_indices,
has_initial_state,
)
outs_ref = []
splits = [
torch.split(var, seqlens[0], dim=-1)
@@ -530,33 +558,46 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_state[i] else None,
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
0))
if has_initial_state[i]
else None,
final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0),
)
outs_ref.append(out_ref_s)
out_ref = torch.cat(outs_ref, dim=-1)[0]
unpadded_out = out[:, :out_ref[0].shape[-1]]
unpadded_out = out[:, : out_ref[0].shape[-1]]
print("Output diff max", (unpadded_out - out_ref).max())
print("Output diff mean", (unpadded_out - out_ref).mean())
print("Output state diff max", (prev_state - prev_state_ref).max())
print("Output state diff mean", (prev_state - prev_state_ref).mean())
assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, padded_state_indices,
has_initial_state, prev_state)
selective_scan_opcheck_fn(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
cumsum,
padded_state_indices,
has_initial_state,
prev_state,
)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
has_z, itype):
def test_selective_state_update_with_batch_indices(
with_padding, dim, dstate, has_z, itype
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
@@ -571,17 +612,17 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
@@ -593,61 +634,60 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].clone()
state_before = state.clone()
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x[:batch_size],
dt[:batch_size],
A,
B[:batch_size],
C[:batch_size],
D=D,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True)
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref,
x[:batch_size],
dt[:batch_size],
A,
B[:batch_size],
C[:batch_size],
D=D,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True,
)
print("Output diff max", (out[:batch_size] - out_ref).max())
print("Output diff mean", (out[:batch_size] - out_ref).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool],
state[unused_states_bool])
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :])
assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :])
assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :])
assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :])
# test "real" entries
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
dim, dstate, ngroups, has_z, tie_hdim, itype):
dim, dstate, ngroups, has_z, tie_hdim, itype
):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
@@ -659,71 +699,55 @@ def test_selective_state_update_with_heads_with_batch_indices(
nheads = dim // headdim
total_entries = 10 * batch_size
state = torch.randn(total_entries,
nheads,
headdim,
dstate,
dtype=itype,
device=device)
state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device
)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
dtype=torch.int32, device=device
)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
out = torch.empty_like(x)
if not tie_hdim:
dt = torch.randn(batch_size,
nheads,
headdim,
device=device,
dtype=itype)
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(torch.randn(batch_size, nheads, device=device,
dtype=itype),
"b h -> b h p",
p=headdim)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
"h -> h p",
p=headdim)
A = repeat(-torch.rand(nheads, device=device) - 1.0,
"h -> h p n",
p=headdim,
n=dstate)
dt = repeat(
torch.randn(batch_size, nheads, device=device, dtype=itype),
"b h -> b h p",
p=headdim,
)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
A = repeat(
-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

View File

@@ -7,10 +7,10 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
mamba_chunk_scan_combined_varlen,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
compute_varlen_chunk_metadata)
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata
# Added by the IBM Team, 2024
@@ -22,12 +22,10 @@ def segsum(x):
"""Calculates segment sum."""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=-1)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=0)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
@@ -46,8 +44,9 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
for x in (X, A, B, C))
X, A, B, C = (
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
)
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
@@ -74,7 +73,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
# (diagonal and off-diagonal blocks)
@@ -82,42 +81,31 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
return Y, final_state
def generate_random_inputs(batch_size,
seqlen,
n_heads,
d_head,
itype,
device='cuda'):
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
current_platform.seed_everything(0)
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
dt = F.softplus(
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
4)
X = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
)
X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
return A, dt, X, B, C
def generate_continuous_batched_examples(example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device='cuda',
return_naive_ref=True):
def generate_continuous_batched_examples(
example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device="cuda",
return_naive_ref=True,
):
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
# them in continuous batches to the kernels.
@@ -126,23 +114,20 @@ def generate_continuous_batched_examples(example_lens_by_batch,
# reference output.
# generate the full-length example
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
d_head, itype)
A, dt, X, B, C = generate_random_inputs(
num_examples, full_length, n_heads, d_head, itype
)
if return_naive_ref:
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
A * dt,
B,
C,
block_len=full_length //
4)
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
)
# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def get_continuous_batch(example_lens: tuple[int, ...]):
indices = []
for i, x in enumerate(example_lens):
c = last_taken.get(i, 0)
@@ -150,8 +135,10 @@ def generate_continuous_batched_examples(example_lens_by_batch,
last_taken[i] = (c + x) % full_length
exhausted[i] = last_taken[i] == 0
return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
]).unsqueeze(0) for x in (dt, X, B, C))
return (
torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
for x in (dt, X, B, C)
)
# internal function that maps "n" to the appropriate right boundary
# value when forming continuous batches from examples of length given
@@ -163,19 +150,20 @@ def generate_continuous_batched_examples(example_lens_by_batch,
IND_E = None
for spec in example_lens_by_batch:
# get the (maybe partial) example seen in this cont batch
dt2, X2, B2, C2 = get_continuous_batch(spec)
# get the metadata
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
seq_idx = torch.zeros(cu_seqlens[-1],
dtype=torch.int32,
device=cu_seqlens.device)
for i, (srt, end) in enumerate(zip(
cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
seq_idx = torch.zeros(
cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
)
for i, (srt, end) in enumerate(
zip(
cu_seqlens,
cu_seqlens[1:],
)):
)
):
seq_idx[srt:end] = i
# for cont batch
@@ -190,19 +178,21 @@ def generate_continuous_batched_examples(example_lens_by_batch,
X2 = X2.squeeze(0)
B2 = B2.squeeze(0)
C2 = C2.squeeze(0)
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
yield (
[Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
if return_naive_ref
else None,
cu_seqlens,
seq_idx,
(A, dt2, X2, B2, C2),
)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
# this tests the kernels on a single example (bs=1)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
@@ -219,15 +209,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# it is not an operational limitation.
seqlen, chunk_size = seq_len_chunk_size
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
d_head, itype)
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
)
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
# varlen has implicit batch=1
X = X.squeeze(0)
dt = dt.squeeze(0)
@@ -255,10 +246,12 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
torch.testing.assert_close(
final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol,
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@@ -267,32 +260,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths
(
64,
8,
2,
[(4, 4), (4, 4), (4, 4), (4, 4)],
), # chunk_size larger than cont batches
(
64,
8,
5,
[
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
],
), # mode examples with varied lengths
# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
(
64,
256,
2,
[(5, 30), (1, 2), (1, 2), (1, 2)],
), # irregular sizes with small sequences
# we also need to test some large seqlen
# to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]),
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
],
)
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
@@ -311,12 +312,17 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
states = None
for Y_min, cu_seqlens, _token_seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype):
A,
dt,
X,
B,
C,
) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
):
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined_varlen(
@@ -337,9 +343,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test the last in sequence
for i in range(num_examples):
# just test one dim and dstate
Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
@@ -347,18 +352,20 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.)
states[i].fill_(0.0)
exhausted[i] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
])
@pytest.mark.parametrize(
"seqlens",
[
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# This test verifies the correctness of the chunked prefill implementation
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
# dimension) of chunked results with the full sequence result.
@@ -387,21 +394,25 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
generate_continuous_batched_examples([seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False))
generate_continuous_batched_examples(
[seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False,
)
)
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
device = X.device
## full seqlen computation
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined_varlen(
X,
@@ -422,11 +433,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
## chunked seqlen computation
# first chunk
chunked_seqlens = seqlens // 2
chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(chunked_seqlens, dim=0)
],
dim=0)
chunked_cu_seqlens = torch.cat(
[torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -443,7 +452,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# fmt: on
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
)
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined_varlen(
X_chunked,
@@ -463,11 +473,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# remaining chunk
remaining_chunked_seqlens = seqlens - chunked_seqlens
remaining_chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0)
],
dim=0)
remaining_chunked_cu_seqlens = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0),
],
dim=0,
)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
@@ -497,8 +509,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
chunk_size))
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size)
)
Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined_varlen(
@@ -520,20 +532,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:chunked_seqlens[i], ...],
Y_ref_seq[:chunked_seqlens[i], ...],
Y_seq[: chunked_seqlens[i], ...],
Y_ref_seq[: chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
msg=lambda x: f"seq{i} output part1 " + x,
) # noqa: B023
torch.testing.assert_close(
Y_seq[chunked_seqlens[i]:, ...],
Y_ref_seq[chunked_seqlens[i]:, ...],
Y_seq[chunked_seqlens[i] :, ...],
Y_ref_seq[chunked_seqlens[i] :, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
msg=lambda x: f"seq{i} output part2 " + x,
) # noqa: B023
state_seq = state_chunked[i]
state_seq_ref = state_ref[i]
@@ -542,4 +556,5 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
state_seq_ref,
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} state " + x) # noqa: B023
msg=lambda x: f"seq{i} state " + x,
) # noqa: B023

View File

@@ -9,18 +9,19 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
from .mk_objects import (
MK_ALL_PREPARE_FINALIZE_TYPES,
MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
)
def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s:
return pf
raise ValueError(
f"Cannot find a PrepareFinalize type that matches {s}")
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
for fe in MK_FUSED_EXPERT_TYPES:
@@ -45,15 +46,18 @@ def make_config_arg_parser(description: str):
"--pf-type",
type=to_pf_class_type,
required=True,
help=("Choose a PrepareFinalize Type : "
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
help=(
"Choose a PrepareFinalize Type : "
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"
),
)
parser.add_argument(
"--experts-type",
type=to_experts_class_type,
required=True,
help=(f"Choose a FusedExpert type : "
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
help=(
f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"
),
)
parser.add_argument(
"-m",
@@ -74,66 +78,65 @@ def make_config_arg_parser(description: str):
default=1024,
help="N dimension of the first fused-moe matmul",
)
parser.add_argument("--num-experts",
type=int,
default=32,
help="Global num experts")
parser.add_argument("--topk",
nargs="+",
type=int,
default=[4, 1],
help="num topk")
parser.add_argument(
"--num-experts", type=int, default=32, help="Global num experts"
)
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl."
help="Fused moe chunk size used for the non-batched fused experts impl.",
)
# Quant args
parser.add_argument("--quant-dtype",
type=to_quant_torch_dtype,
help="Quant datatype")
parser.add_argument("--per-token-quantized-activations",
action='store_true',
help=("The input activations must be per-token "
"quantized"))
parser.add_argument("--per-channel-quantized-weights",
action="store_true",
help="The weights must be per-channel quantized.")
parser.add_argument("--block-shape",
nargs="+",
type=int,
help="Quantization block shape")
parser.add_argument(
"--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype"
)
parser.add_argument(
"--per-token-quantized-activations",
action="store_true",
help=("The input activations must be per-token quantized"),
)
parser.add_argument(
"--per-channel-quantized-weights",
action="store_true",
help="The weights must be per-channel quantized.",
)
parser.add_argument(
"--block-shape", nargs="+", type=int, help="Quantization block shape"
)
# Torch trace profile generation args
parser.add_argument("--torch-trace-dir-path",
type=str,
default=None,
help="Get torch trace for single execution")
parser.add_argument(
"--torch-trace-dir-path",
type=str,
default=None,
help="Get torch trace for single execution",
)
return parser
def _validate_args(args: argparse.Namespace):
if args.quant_dtype is not None:
assert args.quant_dtype == torch.float8_e4m3fn
if args.block_shape is not None:
assert len(args.block_shape) == 2, (
f"block shape must have 2 elements. got {args.block_shape}")
f"block shape must have 2 elements. got {args.block_shape}"
)
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
assert args.world_size == 1, (
"Single GPU objects need world size set to 1")
assert args.world_size == 1, "Single GPU objects need world size set to 1"
if args.torch_trace_dir_path is not None:
from pathlib import Path
assert Path(args.torch_trace_dir_path).is_dir(), (
f"Please create {args.torch_trace_dir_path}")
f"Please create {args.torch_trace_dir_path}"
)
def make_config(args: argparse.Namespace) -> Config:
_validate_args(args)
quant_config = None
@@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config:
quant_dtype=args.quant_dtype,
per_act_token_quant=args.per_token_quantized_activations,
per_out_ch_quant=args.per_channel_quantized_weights,
block_shape=args.block_shape)
block_shape=args.block_shape,
)
return Config(
Ms=args.m,
@@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config:
fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path)
torch_trace_dir_path=args.torch_trace_dir_path,
)

View File

@@ -8,20 +8,30 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .mk_objects import (
TestMoEQuantConfig,
expert_info,
make_fused_experts,
make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
@@ -94,8 +104,7 @@ class Config:
@property
def is_per_tensor_act_quant(self) -> bool:
return (not self.is_per_act_token_quant
and self.quant_block_shape is None)
return not self.is_per_act_token_quant and self.quant_block_shape is None
@property
def is_per_out_ch_quant(self) -> bool:
@@ -134,23 +143,24 @@ class Config:
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
)
return vllm_config, env_dict
def is_fp8_block_quantized(self):
return (self.quant_dtype == torch.float8_e4m3fn
and self.quant_block_shape is not None)
return (
self.quant_dtype == torch.float8_e4m3fn
and self.quant_block_shape is not None
)
def is_batched_prepare_finalize(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
def is_batched_fused_experts(self):
info = expert_info(self.fused_experts_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
def is_standard_fused_experts(self):
info = expert_info(self.fused_experts_type)
@@ -190,8 +200,10 @@ class Config:
def needs_deep_ep(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return (info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency")
return (
info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency"
)
def all2all_backend(self):
info = prepare_finalize_info(self.prepare_finalize_type)
@@ -211,20 +223,26 @@ class Config:
return False
# Check quantization sanity
if (int(self.is_per_act_token_quant) +
int(self.is_per_tensor_act_quant) +
int(self.quant_block_shape is not None)) > 1:
if (
int(self.is_per_act_token_quant)
+ int(self.is_per_tensor_act_quant)
+ int(self.quant_block_shape is not None)
) > 1:
# invalid quant config
return False
# check type support
if self.quant_dtype is None:
if (self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()):
if (
self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()
):
return False
else:
if (self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()):
if (
self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()
):
return False
# Check block quanization support
@@ -261,18 +279,21 @@ class WeightTensors:
def describe(self):
s = ""
s += "== Weight Tensors: \n"
s += f' - {_describe_tensor(self.w1, "w1")} \n'
s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
s += f" - {_describe_tensor(self.w1, 'w1')} \n"
s += f" - {_describe_tensor(self.w2, 'w2')} \n"
s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
return (
self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8
or self.w1.dtype == torch.int8
)
def to_current_device(self):
device = torch.cuda.current_device()
@@ -289,16 +310,13 @@ class WeightTensors:
if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
w1_scale = self.w1_scale[
s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[
s:e, :, :] if self.w2_scale is not None else None
w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
@@ -313,15 +331,11 @@ class WeightTensors:
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
per_out_ch_quant=config.is_per_act_token_quant, # or config.is_per_out_ch_quant
)
return WeightTensors(
w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass
@@ -336,22 +350,22 @@ class RankTensors:
def describe(self):
s = ""
s += "== Rank Tensors: \n"
s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
return s
@staticmethod
def make_hidden_states(
config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
config: Config,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Return hidden_states
"""
m, k, dtype = (config.M, config.K, config.dtype)
a = (torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
if config.quant_dtype is None:
return a, None
@@ -362,36 +376,29 @@ class RankTensors:
# first - so further quantize and dequantize will yield the same
# values.
if config.is_per_tensor_act_quant:
a_q, a_scales = ops.scaled_fp8_quant(
a, use_per_token_if_dynamic=False)
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
return a_q.float().mul(a_scales).to(dtype), a_scales
if config.is_per_act_token_quant:
a_q, a_scales = ops.scaled_fp8_quant(a,
use_per_token_if_dynamic=True)
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
return a_q.float().mul(a_scales).to(dtype), None
assert config.quant_block_shape is not None
block_k = config.quant_block_shape[1]
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
return a_q.float().view(
(-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
dtype
), None
@staticmethod
def make(config: Config, pgi: ProcessGroupInfo):
dtype = config.dtype
topk, m, _ = (config.topk, config.M, config.K)
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
config)
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)
num_local_experts, global_num_experts = (config.num_local_experts,
config.E)
score = torch.randn((m, global_num_experts),
device="cuda",
dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False)
num_local_experts, global_num_experts = (config.num_local_experts, config.E)
score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
# distribute topk_ids evenly
for mi in range(m):
@@ -400,14 +407,15 @@ class RankTensors:
expert_map = None
if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full(
(global_num_experts,), fill_value=-1, dtype=torch.int32
)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
expert_map = expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
expert_map = expert_map.to(
device=torch.cuda.current_device(), dtype=torch.int32
)
return RankTensors(
hidden_states=hidden_states,
@@ -418,9 +426,9 @@ class RankTensors:
)
def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor:
def reference_moe_impl(
config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
if config.quant_dtype == "nvfp4":
quant_blocksize = 16
dtype = config.dtype
@@ -433,8 +441,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
/ torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
@@ -447,14 +457,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
rank_tensors.hidden_states, a_global_scale
)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
a = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize,
)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
@@ -464,18 +477,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize,
)
w2[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize,
)
a_scale = None
w1_scale = None
w2_scale = None
@@ -493,27 +510,29 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
return torch_experts(
a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input(),
)
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
return torch.ones(
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
)
def make_modular_kernel(
@@ -521,12 +540,12 @@ def make_modular_kernel(
vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))
return 2 ** math.ceil(math.log2(x))
# make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
@@ -546,9 +565,9 @@ def make_modular_kernel(
)
# make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe,
quant_config)
prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
)
fused_experts = make_fused_experts(
config.fused_experts_type,
@@ -559,7 +578,8 @@ def make_modular_kernel(
)
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
prepare_finalize=prepare_finalize, fused_experts=fused_experts
)
return modular_kernel
@@ -587,10 +607,8 @@ def run_modular_kernel(
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
@@ -603,38 +621,30 @@ def run_modular_kernel(
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = {
"hidden_states":
hidden_states,
"w1":
rank_weights.w1,
"w2":
rank_weights.w2,
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
topk_ids,
"expert_map":
rank_tensors.expert_map,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
"hidden_states": hidden_states,
"w1": rank_weights.w1,
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": topk_ids,
"expert_map": rank_tensors.expert_map,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1
and config.supports_apply_weight_on_input(),
}
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
num_tokens_across_dp = torch.tensor(
[num_tokens] * config.world_size, device="cuda", dtype=torch.int
)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)

View File

@@ -10,14 +10,21 @@ import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
run_modular_kernel)
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
from .common import (
Config,
RankTensors,
WeightTensors,
reference_moe_impl,
run_modular_kernel,
)
from .mk_objects import (
MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS,
)
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
@@ -38,8 +45,9 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()
@@ -60,8 +68,7 @@ def rank_worker(
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
@@ -70,28 +77,27 @@ def rank_worker(
def make_feature_matrix(csv_file_path: str):
from dataclasses import asdict
import pandas as pd
def add_to_results(config: Config,
success: Result,
results_df: Optional[pd.DataFrame] = None):
def add_to_results(
config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
):
config_dict = asdict(config)
config_dict['prepare_finalize_type'] = config_dict[
'prepare_finalize_type'].__name__
config_dict['fused_experts_type'] = config_dict[
'fused_experts_type'].__name__
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
config_dict["prepare_finalize_type"] = config_dict[
"prepare_finalize_type"
].__name__
config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__
config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant
quant_config_dict = config_dict["quant_config"]
del config_dict["quant_config"]
if quant_config_dict is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict
result_dict = config_dict | {'success': success.name}
result_dict = config_dict | {"success": success.name}
result_df = pd.DataFrame([result_dict])
if results_df is None:
@@ -112,22 +118,26 @@ def make_feature_matrix(csv_file_path: str):
Q_TYPES = MK_QUANT_CONFIGS
combinations = list(
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
)
results_df: Optional[pd.DataFrame] = None
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
combinations): #noqa: E501
config = Config(Ms=[m],
K=k,
N=n,
E=e,
topks=topks,
dtype=dtype,
prepare_finalize_type=pf_type,
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None)
combinations
): # noqa: E501
config = Config(
Ms=[m],
K=k,
N=n,
E=e,
topks=topks,
dtype=dtype,
prepare_finalize_type=pf_type,
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None,
)
success = None
if config.is_valid():
@@ -135,9 +145,14 @@ def make_feature_matrix(csv_file_path: str):
try:
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker,
vllm_config, env_dict, config,
weights)
parallel_launch_with_config(
config.world_size,
rank_worker,
vllm_config,
env_dict,
config,
weights,
)
success = Result.PASS
except Exception as _:
success = Result.FAIL
@@ -150,25 +165,33 @@ def make_feature_matrix(csv_file_path: str):
results_df.to_csv(f"{csv_file_path}")
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
from pathlib import Path
parser = argparse.ArgumentParser(description=(
"Make ModularKernel feature matrix \n"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
"-f ./feature_matrices/feature_matrix.csv"))
parser.add_argument("-f",
"--feature-matrix-csv-file-path",
type=str,
required=True,
help="File name to Generate a .csv file")
parser = argparse.ArgumentParser(
description=(
"Make ModularKernel feature matrix \n"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501
"-f ./feature_matrices/feature_matrix.csv"
)
)
parser.add_argument(
"-f",
"--feature-matrix-csv-file-path",
type=str,
required=True,
help="File name to Generate a .csv file",
)
args = parser.parse_args()
csv_path = args.feature_matrix_csv_file_path
assert csv_path.endswith(
'csv'), f"Need a file path ending with .csv, got {csv_path}"
assert Path(csv_path).parent.is_dir(
), f"Cannot find parent directory for {Path(csv_path).parent}"
assert csv_path.endswith("csv"), (
f"Need a file path ending with .csv, got {csv_path}"
)
assert Path(csv_path).parent.is_dir(), (
f"Cannot find parent directory for {Path(csv_path).parent}"
)
make_feature_matrix(args.feature_matrix_csv_file_path)

View File

@@ -8,24 +8,33 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
BatchedTritonExperts,
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
cutlass_fp4_supported,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
cutlass_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
@@ -60,8 +69,7 @@ class ExpertInfo:
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
@@ -71,7 +79,10 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
torch.float8_e4m3fn,
torch.bfloat16,
torch.float16,
torch.float32,
]
common_float_and_int_types = common_float_types + [torch.int8]
nvfp4_types = ["nvfp4"]
@@ -186,9 +197,11 @@ register_experts(
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
@@ -208,7 +221,9 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize,
)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
@@ -217,13 +232,14 @@ if has_pplx():
backend="pplx",
)
if (has_flashinfer_cutlass_fused_moe()
and current_platform.has_device_capability(100)):
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize)
create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
@@ -258,16 +274,18 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
(
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
)
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
@@ -290,8 +308,11 @@ if has_deep_gemm() and is_deep_gemm_supported():
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
)
register_experts(
CutlassExpertsFp8,
standard_format,
@@ -310,8 +331,8 @@ if cutlass_fp8_supported():
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
register_experts(
CutlassExpertsFp4,
standard_format,
@@ -324,30 +345,40 @@ if cutlass_fp4_supported():
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None,
# per-channel / per-column weights and per-tensor activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None,
),
# per-channel / per-column weights and per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None,
),
# per-tensor weights and per-tensor activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None,
),
# per-tensor weights and per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None,
),
# block-quantized weights and 128 block per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128],
),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
@@ -355,10 +386,12 @@ MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None,
),
]
@@ -370,12 +403,14 @@ def make_prepare_finalize(
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
moe, quant_config
)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1)
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()
@@ -391,10 +426,10 @@ def make_cutlass_strides(
n: int,
k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
return ab_strides1, ab_strides2, c_strides1, c_strides2
@@ -405,7 +440,6 @@ def make_fused_experts(
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,

View File

@@ -6,13 +6,11 @@ import traceback
from typing import Any, Callable, Optional
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.utils import get_open_port
## Parallel Processes Utils
@@ -30,10 +28,11 @@ class ProcessGroupInfo:
device: torch.device
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
local_rank: int):
def _set_vllm_config(
vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
):
import tempfile
temp_file = tempfile.mkstemp()[1]
with set_current_vllm_config(vllm_config):
@@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
)
initialize_model_parallel(
tensor_model_parallel_size=vllm_config.parallel_config.
tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.
pipeline_parallel_size,
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
)
cpu_group = torch.distributed.new_group(list(range(world_size)),
backend="gloo")
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
return cpu_group
@@ -62,8 +58,7 @@ def _worker_parallel_launch(
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
P], None],
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
vllm_config: Optional[VllmConfig],
env_dict: Optional[dict],
*args: P.args,
@@ -131,7 +126,8 @@ def parallel_launch_with_config(
worker,
vllm_config,
env_dict,
) + args,
)
+ args,
nprocs=world_size,
join=True,
)

View File

@@ -14,28 +14,31 @@ from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
def do_profile(fn: Callable,
fn_kwargs: dict[Any, Any],
pgi: ProcessGroupInfo,
config: Config,
num_warmups: int = 5):
def do_profile(
fn: Callable,
fn_kwargs: dict[Any, Any],
pgi: ProcessGroupInfo,
config: Config,
num_warmups: int = 5,
):
for _ in range(num_warmups):
fn(**fn_kwargs)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=True,
) as tprof:
fn(**fn_kwargs)
torch.cuda.synchronize(torch.cuda.current_device())
# TODO (varun): Add a descriptive trace file name
tprof.export_chrome_trace(
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json"
)
def profile_modular_kernel(
@@ -82,6 +85,7 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -108,20 +112,25 @@ def rank_worker(
def run(config: Config):
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights)
parallel_launch_with_config(
config.world_size, rank_worker, vllm_config, env_dict, config, weights
)
if __name__ == '__main__':
if __name__ == "__main__":
from .cli_args import make_config, make_config_arg_parser
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
parser = make_config_arg_parser(
description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
)
)
args = parser.parse_args()
assert args.torch_trace_dir_path is not None, (
"Please pass in a directory to store torch traces")
"Please pass in a directory to store torch traces"
)
config = make_config(args)
run(config)

View File

@@ -3,6 +3,7 @@
"""
DeepEP test utilities
"""
import dataclasses
import os
import traceback
@@ -10,17 +11,18 @@ from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port, has_deep_ep
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
## Parallel Processes Utils
@@ -96,7 +98,8 @@ def parallel_launch(
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
)
+ args,
nprocs=world_size,
join=True,
)
@@ -118,48 +121,57 @@ class DeepEPLLArgs:
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_ht_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
num_dispatchers=pgi.world_size,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)
buffer = deep_ep.Buffer(
group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank,
)
return DeepEPHTPrepareAndFinalize(
buffer=buffer,
num_dispatchers=pgi.world_size,
dp_size=dp_size,
rank_expert_offset=pgi.rank * ht_args.num_local_experts,
)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_ll_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
deepep_ll_args.max_tokens_per_rank,
deepep_ll_args.hidden_size,
pgi.world_size,
deepep_ll_args.num_experts,
)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
buffer = deep_ep.Buffer(
group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size,
)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
@@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
return make_deepep_ht_a2a(
pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)

View File

@@ -5,13 +5,14 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
BatchedPrepareAndFinalize,
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
@@ -19,15 +20,15 @@ from .test_deepgemm import make_block_quant_fp8_weights
BLOCK_SIZE = [128, 128]
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32]) # number of experts
@pytest.mark.parametrize("T", [256, 512]) # tokens per expert
@pytest.mark.parametrize("K", [128, 256]) # hidden dim
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
monkeypatch):
def test_batched_deepgemm_vs_triton(
E: int, T: int, K: int, N: int, topk: int, monkeypatch
):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

View File

@@ -7,14 +7,18 @@ from typing import Optional
import pytest
import torch
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, naive_batched_moe)
from tests.kernels.moe.utils import (
batched_moe,
make_quantized_test_activations,
make_test_weights,
naive_batched_moe,
)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
from vllm.triton_utils import tl
@@ -68,23 +72,32 @@ class BatchedMMTensors:
@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
A = (
torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.in_dtype,
)
/ 10
)
B = torch.randn(
(config.num_experts, config.N, config.K),
device="cuda",
dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.in_dtype)
dtype=config.in_dtype,
)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.out_dtype)
dtype=config.out_dtype,
)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
num_expert_tokens = torch.randint(
low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts,),
device="cuda",
dtype=torch.int32,
)
return BatchedMMTensors(A, B, C, num_expert_tokens)
@@ -96,10 +109,15 @@ class BatchedMMTensors:
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
def test_batched_mm(
num_experts: int,
max_tokens_per_expert: int,
K: int,
N: int,
dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool,
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
@@ -117,11 +135,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
act_dtype = dtype
quant_dtype = None
num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)
num_expert_tokens = torch.randint(
low=0,
high=max_tokens_per_expert,
size=(num_experts,),
device="cuda",
dtype=torch.int32,
)
A, A_q, A_scale = make_quantized_test_activations(
num_experts,
@@ -151,7 +171,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
torch.float32: tl.float32,
}[test_output.dtype]
assert A_q.dtype == B_q.dtype
@@ -173,7 +193,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
},
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
@@ -186,11 +206,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
num_expert_tokens,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape,
per_act_token_quant)
q_ref_output = native_batched_masked_quant_matmul(
A_q,
B_q,
q_ref_output,
num_expert_tokens,
A_scale,
B_scale,
block_shape,
per_act_token_quant,
)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@@ -308,12 +333,6 @@ def test_fused_moe_batched_experts(
block_shape=block_shape,
)
torch.testing.assert_close(batched_output,
baseline_output,
atol=3e-2,
rtol=2e-2)
torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)

View File

@@ -5,15 +5,21 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
_valid_deep_gemm_shape,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
fused_topk,
modular_triton_fused_moe,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@@ -24,8 +30,7 @@ if dg_available:
from deep_gemm import get_m_alignment_for_contiguous_layout
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -97,8 +102,7 @@ TOP_KS = [1, 2, 6]
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
block_shape):
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.size(1)
@@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
inter_out = native_w8a8_block_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
out[mask] = native_w8a8_block_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
# Skip all tests if CUDA is not available
@@ -149,8 +147,9 @@ def setup_cuda():
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
monkeypatch):
def test_w8a8_block_fp8_fused_moe(
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
@@ -188,12 +187,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
block_size,
)
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
out = fused_experts(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
@@ -210,8 +206,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
@@ -245,36 +240,38 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
use_cudagraph = (
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids, block_size)
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
deep_gemm_moe_fp8_fn = torch.compile(
deep_gemm_moe_fp8, backend="inductor", fullgraph=True
)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
out = deep_gemm_moe_fp8_fn(
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()

View File

@@ -5,16 +5,17 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_int8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
inter_out = native_w8a8_block_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
out[mask] = native_w8a8_block_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -131,15 +126,19 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
quant_config.w2_scale, score, topk,
block_size)
out = fused_experts(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
)
ref_out = torch_w8a8_block_int8_moe(
a,
w1,
w2,
quant_config.w1_scale,
quant_config.w2_scale,
score,
topk,
block_size,
)
# Check results
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)

View File

@@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
@dataclasses.dataclass
class TestTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor] = None
@@ -25,32 +24,31 @@ class TestTensors:
self.expert_map = self.expert_map.to(device=device)
@staticmethod
def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
topk_ids_dtype: torch.dtype) -> "TestTensors":
def make(
num_tokens: int,
num_topk: int,
num_experts: int,
device: str,
topk_ids_dtype: torch.dtype,
) -> "TestTensors":
# make topk ids
topk_ids = torch.empty((num_tokens, num_topk),
device=device,
dtype=torch.int64)
topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64)
for x in range(num_tokens):
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
topk_ids = topk_ids.to(dtype=torch.int64)
return TestTensors(topk_ids=topk_ids)
def with_ep_rank(self, ep_rank: int, num_global_experts: int,
num_local_experts: int, device: str):
def with_ep_rank(
self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str
):
# make an expert map
expert_map = torch.empty((num_global_experts),
device=device,
dtype=torch.int32)
expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32)
expert_map.fill_(-1)
s = ep_rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
device=device)
expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device)
return TestTensors(topk_ids=self.topk_ids.clone(),
expert_map=expert_map)
return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map)
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
@@ -68,49 +66,49 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
expert_num_tokens[eid] += count
def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
num_experts: int, ep_size: int,
topk_ids_dtype: torch.dtype):
def do_test_compute_expert_num_tokens(
num_tokens: int,
num_topk: int,
num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype,
):
assert num_topk <= num_experts
tt = TestTensors.make(num_tokens,
num_topk,
num_experts,
topk_ids_dtype=topk_ids_dtype,
device="cpu")
tt = TestTensors.make(
num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu"
)
num_global_experts = num_experts
assert num_global_experts % ep_size == 0
num_local_experts = num_global_experts // ep_size
for ep_rank in range(ep_size):
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
num_local_experts, "cpu")
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu")
ref_expert_num_tokens = torch.zeros((num_local_experts),
device="cpu",
dtype=torch.int32)
ref_expert_num_tokens = torch.zeros(
(num_local_experts), device="cpu", dtype=torch.int32
)
ref_impl(tt_rank, ref_expert_num_tokens)
ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
tt_rank.to_device("cuda")
# Test with expert_map
triton_expert_num_tokens_w_emap = count_expert_num_tokens(
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map
)
# Test without expert map
topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
topk_ids, num_local_experts, expert_map=None)
topk_ids, num_local_experts, expert_map=None
)
torch.testing.assert_close(ref_expert_num_tokens,
triton_expert_num_tokens_w_emap,
atol=0,
rtol=0)
torch.testing.assert_close(ref_expert_num_tokens,
triton_expert_num_tokens_wo_emap,
atol=0,
rtol=0)
torch.testing.assert_close(
ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0
)
torch.testing.assert_close(
ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
@@ -118,22 +116,29 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
num_experts: int, ep_size: int,
topk_ids_dtype: torch.dtype):
do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
ep_size, topk_ids_dtype)
def test_compute_expert_num_tokens(
num_tokens: int,
num_topk: int,
num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype,
):
do_test_compute_expert_num_tokens(
num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype
)
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype):
do_test_compute_expert_num_tokens(num_tokens=numel,
num_topk=1,
num_experts=num_experts,
ep_size=ep_size,
topk_ids_dtype=topk_ids_dtype)
def test_compute_expert_num_tokens_from_numel(
numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype
):
do_test_compute_expert_num_tokens(
num_tokens=numel,
num_topk=1,
num_experts=num_experts,
ep_size=ep_size,
topk_ids_dtype=topk_ids_dtype,
)

View File

@@ -17,19 +17,24 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
])
@pytest.mark.parametrize(
"num_groups, expected_m_per_group, k, n",
[
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
],
)
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
(lambda x: x is None or x.to_int() != 100)(
current_platform.get_device_capability()),
reason="Block Scaled Grouped GEMM is only supported on SM100.")
current_platform.get_device_capability()
),
reason="Block Scaled Grouped GEMM is only supported on SM100.",
)
def test_cutlass_grouped_gemm(
num_groups: int,
expected_m_per_group: int,
@@ -40,8 +45,7 @@ def test_cutlass_grouped_gemm(
device = "cuda"
alignment = 128
group_ms = [
int(expected_m_per_group * random.uniform(0.7, 1.3))
for _ in range(num_groups)
int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
]
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
@@ -58,20 +62,22 @@ def test_cutlass_grouped_gemm(
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, cdiv(n, 128), k // 128),
device=device,
dtype=torch.float))
y_fp8 = (
torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
),
)
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]]
a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
b = y_fp8[0][i].t()
b_scale = y_fp8[1][i].t()
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline
ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline
ops.cutlass_blockwise_scaled_grouped_mm(
out,

View File

@@ -11,13 +11,15 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
FUSED_MOE_UNQUANTIZED_CONFIG,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
cutlass_moe_fp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform
NUM_EXPERTS = [40, 64]
@@ -39,12 +41,11 @@ MNK_FACTORS = [
(224, 3072, 1536),
(32768, 1024, 1024),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
# (7232, 2048, 5120),
# (40000, 2048, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@@ -60,22 +61,25 @@ class MOETensors:
c_strides2: torch.Tensor
@staticmethod
def make_moe_tensors(m: int, k: int, n: int, e: int,
dtype: torch.dtype) -> "MOETensors":
def make_moe_tensors(
m: int, k: int, n: int, e: int, dtype: torch.dtype
) -> "MOETensors":
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
return MOETensors(a=a,
w1=w1,
w2=w2,
ab_strides1=ab_strides1,
c_strides1=c_strides1,
ab_strides2=ab_strides2,
c_strides2=c_strides2)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
return MOETensors(
a=a,
w1=w1,
w2=w2,
ab_strides1=ab_strides1,
c_strides1=c_strides1,
ab_strides2=ab_strides2,
c_strides2=c_strides2,
)
@dataclasses.dataclass
@@ -93,9 +97,9 @@ class MOETensors8Bit(MOETensors):
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
per_act_token: bool,
per_out_channel: bool) -> "MOETensors8Bit":
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
) -> "MOETensors8Bit":
dtype = torch.half
q_dtype = torch.float8_e4m3fn
@@ -106,24 +110,21 @@ class MOETensors8Bit(MOETensors):
k_b_scales = k if per_out_channel else 1
# Get the right scale for tests.
a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
moe_tensors_fp16.w1[expert],
use_per_token_if_dynamic=per_out_channel)
moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
moe_tensors_fp16.w2[expert],
use_per_token_if_dynamic=per_out_channel)
moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
)
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
a_d = a_q.float().mul(a_scale).to(dtype)
@@ -133,31 +134,37 @@ class MOETensors8Bit(MOETensors):
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
return MOETensors8Bit(a=moe_tensors_fp16.a,
w1=moe_tensors_fp16.w1,
w2=moe_tensors_fp16.w2,
ab_strides1=moe_tensors_fp16.ab_strides1,
c_strides1=moe_tensors_fp16.c_strides1,
ab_strides2=moe_tensors_fp16.ab_strides2,
c_strides2=moe_tensors_fp16.c_strides2,
a_q=a_q,
w1_q=w1_q,
w2_q=w2_q,
a_scale=a_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a_d=a_d,
w1_d=w1_d,
w2_d=w2_d)
return MOETensors8Bit(
a=moe_tensors_fp16.a,
w1=moe_tensors_fp16.w1,
w2=moe_tensors_fp16.w2,
ab_strides1=moe_tensors_fp16.ab_strides1,
c_strides1=moe_tensors_fp16.c_strides1,
ab_strides2=moe_tensors_fp16.ab_strides2,
c_strides2=moe_tensors_fp16.c_strides2,
a_q=a_q,
w1_q=w1_q,
w2_q=w2_q,
a_scale=a_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a_d=a_d,
w1_d=w1_d,
w2_d=w2_d,
)
def run_with_expert_maps(num_experts: int, num_local_experts: int,
**cutlass_moe_kwargs):
def run_with_expert_maps(
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
):
def slice_experts():
slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2"
"w1_q",
"w2_q",
"ab_strides1",
"ab_strides2",
"c_strides1",
"c_strides2",
]
full_tensors = {
k: v
@@ -173,9 +180,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
# make expert map
expert_map = [-1] * num_experts
expert_map[s:e] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map,
dtype=torch.int32,
device="cuda")
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
# update cutlass moe arg with expert_map
cutlass_moe_kwargs["expert_map"] = expert_map
@@ -198,18 +203,26 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
return out_tensor
def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
moe_tensors.w2_scale, moe_tensors.a_scale
def run_8_bit(
moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None,
) -> torch.Tensor:
assert not any(
[
t is None
for t in [
moe_tensors.w1_q,
moe_tensors.w2_q,
moe_tensors.w1_scale,
moe_tensors.w2_scale,
moe_tensors.a_scale,
]
]
])
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
@@ -222,16 +235,16 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
)
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'quant_config': quant_config,
"a": moe_tensors.a,
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"ab_strides1": moe_tensors.ab_strides1,
"ab_strides2": moe_tensors.ab_strides2,
"c_strides1": moe_tensors.c_strides1,
"c_strides2": moe_tensors.c_strides2,
"quant_config": quant_config,
}
num_experts = moe_tensors.w1.size(0)
@@ -243,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
return run_with_expert_maps(
num_experts,
num_local_experts, # type: ignore[arg-type]
**kwargs)
**kwargs,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -253,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_no_graph(
m: int,
n: int,
@@ -269,25 +285,18 @@ def test_cutlass_moe_8_bit_no_graph(
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
triton_output = fused_experts(
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
@@ -295,15 +304,15 @@ def test_cutlass_moe_8_bit_no_graph(
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
per_out_ch, number_local_experts)
cutlass_output = run_8_bit(
mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output,
cutlass_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(
triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -313,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph(
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_cuda_graph(
m: int,
n: int,
@@ -330,39 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph(
with set_current_vllm_config(vllm_config):
dtype = torch.half
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
triton_output = fused_experts(
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token, per_out_ch)
cutlass_output = run_8_bit(
mt, topk_weights, topk_ids, per_act_token, per_out_ch
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(triton_output,
cutlass_output,
atol=9e-2,
rtol=1e-2)
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
@pytest.mark.parametrize("m", [64])
@@ -375,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP(
m: int,
n: int,
@@ -388,8 +392,9 @@ def test_cutlass_moe_8_bit_EP(
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)
LARGE_MNK_FACTORS = [
@@ -406,8 +411,10 @@ LARGE_MNK_FACTORS = [
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP_large(
m: int,
n: int,
@@ -419,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large(
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@@ -430,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large(
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_run_cutlass_moe_fp8(
m: int,
n: int,
@@ -444,14 +454,12 @@ def test_run_cutlass_moe_fp8(
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
mt = MOETensors8Bit.make_moe_tensors_8bit(
m, k, n, e, per_act_token, per_out_channel
)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# we want to make sure there is at least one token that's generated in
# this expert shard and at least one token that's NOT generated in this
# expert shard
@@ -462,12 +470,12 @@ def test_run_cutlass_moe_fp8(
workspace2_shape = (m * topk, max(n, k))
output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
dtype=mt.a.dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device="cuda",
dtype=mt.a.dtype)
workspace13 = torch.empty(
prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
)
workspace2 = torch.empty(
prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
)
num_local_experts = e // ep_size
start, end = 0, num_local_experts
@@ -475,36 +483,55 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
per_act_token)
a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
)
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False, topk_weights)
output,
a1q,
mt.w1_q,
mt.w2_q,
topk_ids,
activation,
global_num_experts,
expert_map,
mt.w1_scale,
mt.w2_scale,
a1q_scale,
None,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
workspace13,
workspace2,
None,
mt.a.dtype,
per_act_token,
per_out_channel,
False,
topk_weights,
)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,
device="cuda",
dtype=mt.a.dtype)
output_random_workspace = torch.empty(
output_shape, device="cuda", dtype=mt.a.dtype
)
func(output_random_workspace)
workspace13.fill_(0)
output_zero_workspace = torch.zeros(output_shape,
device="cuda",
dtype=mt.a.dtype)
output_zero_workspace = torch.zeros(
output_shape, device="cuda", dtype=mt.a.dtype
)
func(output_zero_workspace)
torch.testing.assert_close(output_random_workspace,
output_zero_workspace,
atol=5e-3,
rtol=1e-3)
torch.testing.assert_close(
output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
)

View File

@@ -16,10 +16,11 @@ from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
@@ -30,18 +31,19 @@ from .utils import make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm():
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep(),
@@ -58,9 +60,10 @@ P = ParamSpec("P")
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))
return 2 ** math.ceil(math.log2(x))
def make_block_quant_fp8_weights(
@@ -72,13 +75,9 @@ def make_block_quant_fp8_weights(
"""
Return weights w1q, w2q, w1_scale, w2_scale
"""
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e,
n,
k,
torch.bfloat16,
torch.float8_e4m3fn,
block_shape=block_size)
(_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
)
return w1q, w2q, w1_scale, w2_scale
@@ -106,15 +105,15 @@ class TestTensors:
@staticmethod
def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16
topk, m, k = (config.topk, config.m, config.k)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = (
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
)
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None
@@ -122,25 +121,32 @@ class TestTensors:
low=0,
high=config.num_experts,
size=(m, topk),
device=torch.cuda.current_device()).to(dtype=torch.int64)
device=torch.cuda.current_device(),
).to(dtype=torch.int64)
topk_weights = torch.randn(topk_ids.shape,
dtype=torch.float32,
device=torch.cuda.current_device())
topk_weights = torch.randn(
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
)
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk_ids,
topk_weights=topk_weights,
config=config)
return TestTensors(
rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk_ids,
topk_weights=topk_weights,
config=config,
)
def make_ll_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
max_tokens_per_rank: int,
dp_size: int,
hidden_size: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
@@ -153,26 +159,30 @@ def make_ll_modular_kernel(
max_tokens_per_rank=max_tokens_per_rank,
hidden_size=hidden_size,
num_experts=test_config.num_experts,
use_fp8_dispatch=test_config.use_fp8_dispatch),
use_fp8_dispatch=test_config.use_fp8_dispatch,
),
q_dtype=q_dtype,
block_shape=test_config.block_size)
block_shape=test_config.block_size,
)
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
def make_ht_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
@@ -183,76 +193,82 @@ def make_ht_modular_kernel(
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
deepep_ll_args=None,
q_dtype=q_dtype,
block_shape=test_config.block_size)
block_shape=test_config.block_size,
)
fused_experts = DeepGemmExperts(quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
def make_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
num_local_experts: int,
test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
mk: FusedMoEModularKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
hidden_size = test_tensors.rank_tokens.size(-1)
mk = make_ll_modular_kernel(pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config,
quant_config=quant_config)
mk = make_ll_modular_kernel(
pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config,
quant_config=quant_config,
)
else:
mk = make_ht_modular_kernel(pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config)
mk = make_ht_modular_kernel(
pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config,
)
return mk
def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, test_tensors: TestTensors,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
def deepep_deepgemm_moe_impl(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
) -> torch.Tensor:
test_config = test_tensors.config
num_experts = test_config.num_experts
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
# Low-Latency kernels can't dispatch scales.
a1_scale=(None if test_config.low_latency else
test_tensors.rank_token_scales),
a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
block_shape=test_config.block_size,
)
@@ -263,26 +279,35 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors,
quant_config=quant_config)
quant_config=quant_config,
)
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False)
out = mk.forward(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
return out
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]):
def triton_impl(
a: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: list[int],
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -300,7 +325,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
quant_config=quant_config,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False)
allow_deep_gemm=False,
)
def _test_deepep_deepgemm_moe(
@@ -321,22 +347,21 @@ def _test_deepep_deepgemm_moe(
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [
w1.size(1) // w1_scale.size(1),
w1.size(2) // w1_scale.size(2)
]
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
with set_current_vllm_config(VllmConfig()):
# Reference
triton_moe = triton_impl(a=test_tensors.rank_tokens,
topk_ids=test_tensors.topk,
topk_weights=test_tensors.topk_weights,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=test_tensors.rank_token_scales,
block_shape=block_shape)
triton_moe = triton_impl(
a=test_tensors.rank_tokens,
topk_ids=test_tensors.topk,
topk_weights=test_tensors.topk_weights,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=test_tensors.rank_token_scales,
block_shape=block_shape,
)
# Slice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@@ -390,10 +415,15 @@ NUM_EXPERTS = [32]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ht_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
@@ -409,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
config = TestConfig(topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)
config = TestConfig(
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
num_experts, n, k, block_size
)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)
MNKs = [
@@ -448,8 +489,9 @@ USE_FP8_DISPATCH = [False]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
@@ -482,7 +524,16 @@ def test_ll_deepep_deepgemm_moe(
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
num_experts, n, k, block_size
)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)

View File

@@ -16,12 +16,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
per_token_group_quant_fp8,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
@@ -30,9 +29,11 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
@@ -45,7 +46,7 @@ MAX_TOKENS_PER_RANK = 64
def make_weights(
e, n, k, dtype
e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
@@ -64,17 +65,15 @@ def make_weights(
k_b_scales = k
w1_q = torch.empty_like(w1, dtype=dtype)
w2_q = torch.empty_like(w2, dtype=dtype)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=True)
w1[expert], use_per_token_if_dynamic=True
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=True)
w2[expert], use_per_token_if_dynamic=True
)
return w1_q, w2_q, w1_scale, w2_scale
@@ -100,24 +99,25 @@ class TestTensors:
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
# TODO (varun) - check that float16 works ?
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
else config.dtype)
rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
token_dtype = (
torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype
)
rank_tokens = (
torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
)
rank_token_scales = None
topk = torch.randint(low=0,
high=config.num_experts,
size=(config.m, config.topk),
device="cuda").to(dtype=torch.int64)
topk_weights = torch.randn(topk.shape,
dtype=torch.float32,
device="cuda")
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk,
topk_weights=topk_weights,
config=config)
topk = torch.randint(
low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda"
).to(dtype=torch.int64)
topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda")
return TestTensors(
rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk,
topk_weights=topk_weights,
config=config,
)
def make_modular_kernel(
@@ -132,28 +132,33 @@ def make_modular_kernel(
use_fp8_dispatch: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
if low_latency_mode:
ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
hidden_size=hidden_size,
num_experts=num_experts,
use_fp8_dispatch=use_fp8_dispatch)
ll_args = DeepEPLLArgs(
max_tokens_per_rank=MAX_TOKENS_PER_RANK,
hidden_size=hidden_size,
num_experts=num_experts,
use_fp8_dispatch=use_fp8_dispatch,
)
else:
assert not use_fp8_dispatch, (
"FP8 Dispatch is valid only for low-latency kernels")
"FP8 Dispatch is valid only for low-latency kernels"
)
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
make_deepep_a2a(pg = pg,
pgi = pgi,
dp_size = dp_size,
q_dtype = q_dtype,
block_shape = None,
deepep_ht_args = ht_args,
deepep_ll_args = ll_args)
a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
q_dtype=q_dtype,
block_shape=None,
deepep_ht_args=ht_args,
deepep_ll_args=ll_args,
)
)
num_dispatchers = pgi.world_size // dp_size
@@ -167,8 +172,7 @@ def make_modular_kernel(
else:
fused_experts = TritonExperts(quant_config=quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
@@ -186,19 +190,15 @@ def deep_ep_moe_impl(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
hidden_size = test_tensors.rank_tokens.size(1)
is_quantized = w1.dtype == torch.float8_e4m3fn
@@ -214,11 +214,12 @@ def deep_ep_moe_impl(
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
rank_token_scales_chunk = test_tensors.rank_token_scales
if rank_token_scales_chunk is not None and rank_token_scales_chunk.size(
0) == total_num_tokens:
if (
rank_token_scales_chunk is not None
and rank_token_scales_chunk.size(0) == total_num_tokens
):
# per act token
rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end]
rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]
quant_config = FusedMoEQuantConfig.make(
q_dtype,
@@ -230,26 +231,37 @@ def deep_ep_moe_impl(
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
pg,
pgi,
low_latency_mode,
hidden_size,
dp_size,
num_experts,
num_local_experts,
q_dtype,
use_fp8_dispatch,
quant_config,
)
out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False)
out = mk.forward(
hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
if not skip_result_store:
out_hidden_states[chunk_start:chunk_end, :].copy_(
out, non_blocking=True)
out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)
max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
if low_latency_mode else total_num_tokens)
max_num_tokens_per_dp = (
MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens
)
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
chunk_start = chunk_start_
@@ -258,9 +270,9 @@ def deep_ep_moe_impl(
chunk_start = min(chunk_start, total_num_tokens - 1)
chunk_end = min(chunk_end, total_num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= total_num_tokens)
process_chunk(
chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens
)
return out_hidden_states
@@ -274,9 +286,11 @@ def torch_moe_impl(
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights)
a, topk_ids, topk_weights = (
test_tensors.rank_tokens,
test_tensors.topk,
test_tensors.topk_weights,
)
if using_fp8_dispatch:
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
@@ -284,8 +298,11 @@ def torch_moe_impl(
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
a.shape).to(a.dtype)
a = (
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
.view(a.shape)
.to(a.dtype)
)
is_quantized = w1.dtype == torch.float8_e4m3fn
a_dtype = a.dtype
@@ -306,8 +323,9 @@ def torch_moe_impl(
e_w = topk_weights[i][j]
w1_e = w1[e]
w2_e = w2[e]
o_i += (SiluAndMul()
(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w
o_i += (
SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)
) * e_w
if is_quantized:
out = out.to(dtype=a_dtype)
@@ -327,28 +345,36 @@ def _deep_ep_moe(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
if not low_latency_mode:
assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode")
"FP8 dispatch interface is available only in low-latency mode"
)
is_quantized = w1.dtype == torch.float8_e4m3fn
w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device())
if is_quantized:
w1_scale = w1_scale.to( # type: ignore
device=torch.cuda.current_device())
device=torch.cuda.current_device()
)
w2_scale = w2_scale.to( # type: ignore
device=torch.cuda.current_device())
device=torch.cuda.current_device()
)
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()):
# Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch,
per_act_token_quant)
torch_combined = torch_moe_impl(
test_tensors,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
per_act_token_quant,
)
# Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@@ -420,18 +446,23 @@ def test_deep_ep_moe(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
parallel_launch(
world_size,
_deep_ep_moe,
low_latency_mode,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
per_act_token_quant,
)
MNKs = [
@@ -467,8 +498,7 @@ def test_low_latency_deep_ep_moe(
):
low_latency_mode = True
if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
pytest.skip(
f"Skipping test as hidden size {k} is not in list of supported "
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
@@ -476,15 +506,20 @@ def test_low_latency_deep_ep_moe(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)
parallel_launch(
world_size,
_deep_ep_moe,
low_latency_mode,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
False,
)

View File

@@ -11,14 +11,18 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
per_block_cast_to_fp8)
per_token_group_quant_fp8,
)
from vllm.utils.deep_gemm import (
calc_diff,
is_deep_gemm_supported,
per_block_cast_to_fp8,
)
BLOCK_SIZE = [128, 128]
@@ -37,8 +41,10 @@ def make_block_quant_fp8_weights(
w2 shape: (E, K, N)
"""
dtype = torch.bfloat16
fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
torch.float8_e4m3fn).min
fp8_max, fp8_min = (
torch.finfo(torch.float8_e4m3fn).max,
torch.finfo(torch.float8_e4m3fn).min,
)
# bf16 reference weights
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
@@ -54,24 +60,16 @@ def make_block_quant_fp8_weights(
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty(e,
n_tiles_w1,
k_tiles_w1,
device="cuda",
dtype=torch.float32)
w2_s = torch.empty(e,
n_tiles_w2,
k_tiles_w2,
device="cuda",
dtype=torch.float32)
w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32)
w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=block_size,
use_ue8m0=True)
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=block_size,
use_ue8m0=True)
w1[i], w1_s[i] = per_block_cast_to_fp8(
w1_bf16[i], block_size=block_size, use_ue8m0=True
)
w2[i], w2_s[i] = per_block_cast_to_fp8(
w2_bf16[i], block_size=block_size, use_ue8m0=True
)
return w1, w2, w1_s, w2_s
@@ -81,18 +79,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
Triton baseline within tolerance.
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
tokens_bf16 = (
torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
.clamp_min_(-1)
.clamp_max_(1)
)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
block_size)
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)
router_logits = torch.randn(m,
num_experts,
device="cuda",
dtype=torch.float32)
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
@@ -147,15 +144,14 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe")
"vllm.model_executor.layers.fused_moe.fused_moe"
)
call_counter = {"cnt": 0}
@@ -165,8 +161,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
@@ -181,6 +176,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
)
# ensure that the DeepGEMM path was indeed taken.
assert call_counter["cnt"] == 1, \
f"DeepGEMM path was not executed during the test. " \
assert call_counter["cnt"] == 1, (
f"DeepGEMM path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)

View File

@@ -6,24 +6,28 @@ import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
input_to_float8)
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True,
)
NUM_EXPERTS = [16]
TOP_KS = [1]
@@ -39,8 +43,7 @@ MNK_FACTORS = [
(1, 4096, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@@ -74,18 +77,17 @@ class TestData:
layer: torch.nn.Module
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
reorder: bool) -> "TestData":
hidden_states = torch.randn(
(m, k), device="cuda", dtype=torch.bfloat16) / 10
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool
) -> "TestData":
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
dtype=torch.float32)
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
@@ -102,8 +104,7 @@ class TestData:
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
@@ -145,7 +146,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
@@ -178,12 +180,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
top_k=topk,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True)
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
@pytest.mark.skip(
@@ -213,7 +213,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
@@ -250,7 +251,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_cutlass_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)

View File

@@ -4,26 +4,33 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True,
)
MNK_FACTORS = [
(2, 1024, 1024),
@@ -44,13 +51,13 @@ MNK_FACTORS = [
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
@@ -66,10 +73,7 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
@@ -87,16 +91,19 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
@@ -104,23 +111,26 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
quant_config.w1_scale[idx],
(1 / quant_config.g1_alphas[idx]),
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
quant_config.w2_scale[idx],
(1 / quant_config.g2_alphas[idx]),
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
)
if __name__ == "__main__":

View File

@@ -17,20 +17,21 @@ if not has_triton_kernels():
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
upcast_from_mxfp)
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
BatchedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
BatchedOAITritonExperts,
triton_kernel_moe_forward,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up
@@ -46,13 +47,11 @@ def deshuffle(w: torch.Tensor):
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
randbits = [torch.randperm(E) for _ in range(M)]
x_list = [
(-1)**i *
((16384 +
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
(-1) ** i
* ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
for i, bits in enumerate(randbits)
]
exp_data = torch.stack(x_list).to(
device="cuda") # simulating gate_output (M, E)
exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E)
# create input tensor
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
@@ -120,20 +119,21 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
value=0,
)
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
mode="constant",
value=0)
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
mode="constant",
value=0)
w1_bias_tri = F.pad(
w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0
)
w2_bias_tri = F.pad(
w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0
)
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1)
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w_scale_layout, w_scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
mx_axis=1, num_warps=num_warps
)
)
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
@@ -141,29 +141,33 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
**w_layout_opts)
w1_tri = convert_layout(
wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts
)
w1_scale_tri = convert_layout(
wrap_torch_tensor(w1_scale_tri),
w_scale_layout,
**w_scale_layout_opts,
)
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
**w_layout_opts)
w2_tri = convert_layout(
wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts
)
w2_scale_tri = convert_layout(
wrap_torch_tensor(w2_scale_tri),
w_scale_layout,
**w_scale_layout_opts,
)
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc1 = PrecisionConfig(
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
)
pc2 = PrecisionConfig(
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
)
# tucuate so the rest can run properly
w1 = w1[..., :K, :2 * N]
w1 = w1[..., :K, : 2 * N]
w2 = w2[..., :N, :K]
w1 = deshuffle(w1)
@@ -261,7 +265,8 @@ class Case:
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
tuple(getattr(case, f.name) for f in fields(Case))
for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
@@ -321,10 +326,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
gating_output=exp_data,
topk=topk,
)
assert_close(ref=out_ref,
tri=out_triton_monolithic,
maxtol=0.025,
rmstol=0.005)
assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
def batched_moe(
@@ -376,7 +378,8 @@ def batched_moe(
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
tuple(getattr(case, f.name) for f in fields(Case))
for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")

View File

@@ -4,16 +4,20 @@
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_grouped_topk,
grouped_topk,
)
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@@ -23,23 +27,26 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch,
n_token: int,
n_hidden: int,
n_expert: int,
topk: int,
renormalize: bool,
num_expert_group: int,
topk_group: int,
scoring_func: str,
routed_scaling_factor: float,
dtype: torch.dtype,
):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
e_score_correction_bias = torch.randn(
(n_expert,), dtype=torch.float32, device="cuda"
)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
@@ -52,7 +59,8 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
@@ -63,14 +71,11 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)

View File

@@ -17,18 +17,29 @@ from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl,
run_modular_kernel)
from .modular_kernel_tools.common import (
Config,
RankTensors,
WeightTensors,
reference_moe_impl,
run_modular_kernel,
)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
TestMoEQuantConfig,
expert_info,
)
from .modular_kernel_tools.parallel_utils import (
ProcessGroupInfo,
parallel_launch_with_config,
)
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
or has_flashinfer_cutlass_fused_moe())
has_any_multi_gpu_package = (
has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)
meets_multi_gpu_requirements = pytest.mark.skipif(
not has_any_multi_gpu_package,
@@ -64,9 +75,9 @@ def rank_worker(
# sanity check
from vllm import envs
if base_config.fused_moe_chunk_size is not None:
assert (
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()
@@ -93,8 +104,7 @@ def rank_worker(
rank_tensors = RankTensors.make(config, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
rank_tensors)
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(config, weights, rank_tensors)
@@ -115,10 +125,10 @@ def rank_worker(
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
f"rank={pgi.rank}."
)
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
def run(config: Config, verbose: bool):
@@ -127,8 +137,9 @@ def run(config: Config, verbose: bool):
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights, verbose)
parallel_launch_with_config(
config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
)
Ms = [32, 64]
@@ -149,8 +160,9 @@ def is_nyi_config(config: Config) -> bool:
if info.needs_matching_quant:
# The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1)
unsupported_quant_config = (
config.is_per_act_token_quant + config.is_per_out_ch_quant
) == 1
return unsupported_quant_config
return not info.supports_expert_map
@@ -162,19 +174,25 @@ def is_nyi_config(config: Config) -> bool:
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
config = Config(
Ms=Ms,
K=k,
@@ -195,7 +213,7 @@ def test_modular_kernel_combinations_multigpu(
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption('verbose')
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
@@ -205,16 +223,23 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
config = Config(
Ms=Ms,
K=k,
@@ -235,19 +260,21 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption('verbose')
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
if __name__ == '__main__':
if __name__ == "__main__":
# Ability to test individual PrepareAndFinalize and FusedExperts combination
from .modular_kernel_tools.cli_args import (make_config,
make_config_arg_parser)
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser
parser = make_config_arg_parser(
description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " # noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
)
)
args = parser.parse_args()
config = make_config(args)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union
@@ -21,22 +22,32 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
FUSED_MOE_UNQUANTIZED_CONFIG,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
fused_topk,
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias)
marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
awq_marlin_quantize,
marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -87,13 +98,15 @@ def run_moe_test(
if isinstance(baseline, torch.Tensor):
baseline_output = baseline
else:
baseline_output = baseline(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
baseline_output = baseline(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
# Pad the weight if moe padding is enabled
if padding:
@@ -105,34 +118,35 @@ def run_moe_test(
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(score, 0)
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
test_output = moe_fn(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
if use_cudagraph:
test_output.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
test_output = moe_fn(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(test_output,
baseline_output,
atol=atol,
rtol=rtol)
torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
return baseline_output
@@ -176,11 +190,8 @@ def test_fused_moe(
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
@@ -204,13 +215,15 @@ def test_fused_moe(
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map)
return m_fused_moe_fn(
a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
@@ -234,19 +247,22 @@ def test_fused_moe(
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (n >= 1024 and k >= 1024
and current_platform.is_cuda_alike())
use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
with set_current_vllm_config(vllm_config):
baseline_output = runner(torch_moe, iterative_moe)
runner(baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(
baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
runner(
baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@@ -257,9 +273,18 @@ def test_fused_moe(
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int):
def test_fused_moe_wn16(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
has_zp: bool,
weight_bits: int,
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -274,35 +299,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda",
dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda",
dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda",
dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size),
device="cuda",
dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda",
dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda",
dtype=torch.uint8)
w1_qweight = torch.empty(
(e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
)
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
w1_qzeros = torch.empty(
(e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
)
w2_qzeros = torch.empty(
(e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
)
for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
w, w_ref, w_qweight, w_scales, w_qzeros = (
w1,
w1_ref,
w1_qweight,
w1_scales,
w1_qzeros,
)
else:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
w, w_ref, w_qweight, w_scales, w_qzeros = (
w2,
w2_ref,
w2_qweight,
w2_scales,
w2_qzeros,
)
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False)
w[expert_id].T, quant_type, group_size, has_zp, False
)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
@@ -321,11 +351,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1_ref = w1_ref[e_ids]
w2_ref = w2_ref[e_ids]
@@ -344,28 +371,27 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
assert weight_bits == 8
quant_config_builder = int8_w8a16_moe_quant_config
quant_config = quant_config_builder(w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
quant_config = quant_config_builder(
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size],
)
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
global_num_experts=e,
expert_map=e_map,
quant_config=quant_config)
torch_output = torch_moe(a,
w1_ref,
w2_ref,
score,
topk,
expert_map=e_map)
triton_output = fused_moe(
a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
global_num_experts=e,
expert_map=e_map,
quant_config=quant_config,
)
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -373,16 +399,20 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
use_rocm_aiter: bool, monkeypatch):
def test_mixtral_moe(
dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
is_rocm_aiter_moe_enabled,
)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -390,17 +420,16 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
monkeypatch.setenv("RANK", "0")
monkeypatch.setenv("LOCAL_RANK", "0")
monkeypatch.setenv("WORLD_SIZE", "1")
monkeypatch.setenv("MASTER_ADDR", "localhost")
monkeypatch.setenv("MASTER_PORT", "12345")
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config),
set_forward_context(None, vllm_config)):
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
vllm_moe = MixtralMoE(
@@ -416,27 +445,30 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
# Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
weights = (
hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data,
)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn(
(1, 64, config.hidden_size)).to(dtype).to("cuda")
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
vllm_moe.experts.w13_weight = Parameter(
F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
..., 0:-128
],
requires_grad=False,
)
vllm_moe.experts.w2_weight = Parameter(
F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False,
)
torch.cuda.synchronize()
torch.cuda.empty_cache()
@@ -453,19 +485,21 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
if use_rocm_aiter:
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=0.01,
atol=100)
torch.testing.assert_close(
hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
)
else:
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
torch.testing.assert_close(
hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype],
)
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
@@ -484,16 +518,24 @@ def marlin_moe_generate_valid_test_cases():
]
is_k_full_list = [True, False]
all_combinations = itertools.product(m_list, n_list, k_list, e_list,
topk_list, ep_size_list, dtype_list,
group_size_list, act_order_list,
quant_type_list, is_k_full_list)
all_combinations = itertools.product(
m_list,
n_list,
k_list,
e_list,
topk_list,
ep_size_list,
dtype_list,
group_size_list,
act_order_list,
quant_type_list,
is_k_full_list,
)
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
quant_type, is_k_full):
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
def is_invalid(
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
):
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
@@ -522,9 +564,10 @@ def marlin_moe_generate_valid_test_cases():
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
@@ -549,7 +592,7 @@ def test_fused_marlin_moe(
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
@@ -567,11 +610,13 @@ def test_fused_marlin_moe(
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \
w_ref1, qweight1, scales1, global_scale1 = (
rand_marlin_weight_nvfp4_like(w1[i], group_size)
)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
w1[i], group_size
)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
@@ -580,14 +625,14 @@ def test_fused_marlin_moe(
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -595,9 +640,9 @@ def test_fused_marlin_moe(
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -624,11 +669,13 @@ def test_fused_marlin_moe(
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \
w_ref2, qweight2, scales2, global_scale2 = (
rand_marlin_weight_nvfp4_like(w2[i], group_size)
)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
w2[i], group_size
)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
@@ -637,14 +684,14 @@ def test_fused_marlin_moe(
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -652,9 +699,9 @@ def test_fused_marlin_moe(
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -675,12 +722,7 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
@@ -704,7 +746,8 @@ def test_fused_marlin_moe(
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@@ -738,9 +781,9 @@ def test_fused_marlin_moe_with_bias(m):
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -767,9 +810,9 @@ def test_fused_marlin_moe_with_bias(m):
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -792,8 +835,7 @@ def test_fused_marlin_moe_with_bias(m):
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
b_bias2)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
@@ -817,7 +859,8 @@ def test_fused_marlin_moe_with_bias(m):
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@@ -825,34 +868,36 @@ def test_fused_marlin_moe_with_bias(m):
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4
topk_ids = torch.randint(0,
num_experts, (3, 4),
dtype=torch.int32,
device='cuda')
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
opcheck(torch.ops._moe_C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))
opcheck(
torch.ops._moe_C.moe_align_block_size,
(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
),
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)

View File

@@ -11,7 +11,8 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
moe_align_block_size,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -60,30 +61,33 @@ def _verify_expert_level_sorting(
in topk_ids in the final sorted_ids however this does not impact quality.
"""
# Group tokens by expert from the golden implementation
golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
expert_ids, block_size,
valid_length, total_tokens)
golden_expert_tokens = _group_tokens_by_expert(
golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens
)
actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
expert_ids, block_size,
valid_length, total_tokens)
actual_expert_tokens = _group_tokens_by_expert(
actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens
)
assert set(golden_expert_tokens.keys()) == set(
actual_expert_tokens.keys()), (
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
f"actual={set(actual_expert_tokens.keys())}")
assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), (
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
f"actual={set(actual_expert_tokens.keys())}"
)
for expert_id in golden_expert_tokens:
golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
device=actual_sorted_ids.device)
actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
device=actual_sorted_ids.device)
golden_tokens = torch.tensor(
golden_expert_tokens[expert_id], device=actual_sorted_ids.device
)
actual_tokens = torch.tensor(
actual_expert_tokens[expert_id], device=actual_sorted_ids.device
)
assert torch.equal(
torch.sort(golden_tokens)[0],
torch.sort(actual_tokens)[0]), (
f"Expert {expert_id} token mismatch: "
f"golden={golden_expert_tokens[expert_id]}, "
f"actual={actual_expert_tokens[expert_id]}")
torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0]
), (
f"Expert {expert_id} token mismatch: "
f"golden={golden_expert_tokens[expert_id]}, "
f"actual={actual_expert_tokens[expert_id]}"
)
def torch_moe_align_block_size(
@@ -104,40 +108,38 @@ def torch_moe_align_block_size(
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
flattened_token_indices = torch.arange(topk_ids.numel(),
device=topk_ids.device,
dtype=torch.int32)
flattened_token_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
)
flattened_expert_ids = topk_ids.flatten()
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids,
stable=True)
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True)
sorted_token_indices = flattened_token_indices[sort_indices]
expert_token_counts = torch.zeros(num_experts,
dtype=torch.int64,
device=topk_ids.device)
expert_token_counts = torch.zeros(
num_experts, dtype=torch.int64, device=topk_ids.device
)
for expert_id in range(num_experts):
mask = sorted_expert_ids == expert_id
expert_token_counts[expert_id] = mask.sum()
expert_padded_counts = torch.zeros(num_experts,
dtype=torch.int64,
device=topk_ids.device)
expert_padded_counts = torch.zeros(
num_experts, dtype=torch.int64, device=topk_ids.device
)
for expert_id in range(num_experts):
original_count = expert_token_counts[expert_id]
if original_count > 0:
expert_padded_counts[expert_id] = (
(original_count + block_size - 1) // block_size) * block_size
(original_count + block_size - 1) // block_size
) * block_size
sorted_token_ids = torch.full(
(max_num_tokens_padded, ),
(max_num_tokens_padded,),
topk_ids.numel(),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
expert_ids = torch.zeros(max_num_blocks,
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
current_pos = 0
current_block = 0
@@ -147,20 +149,20 @@ def torch_moe_align_block_size(
num_expert_tokens = expert_tokens.shape[0]
if num_expert_tokens > 0:
sorted_token_ids[current_pos:current_pos +
num_expert_tokens] = (expert_tokens)
sorted_token_ids[current_pos : current_pos + num_expert_tokens] = (
expert_tokens
)
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
expert_ids[current_block:current_block +
expert_blocks_needed] = (expert_id)
expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
current_pos += expert_padded_counts[expert_id]
current_block += expert_blocks_needed
total_padded_tokens = expert_padded_counts.sum()
num_tokens_post_pad = torch.tensor([total_padded_tokens],
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.tensor(
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
@@ -173,37 +175,32 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool):
def test_moe_align_block_size(
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
):
"""Test moe_align_block_size without expert mapping"""
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
for i in range(m):
experts = torch.randperm(num_experts, device="cuda")[:topk]
topk_ids[i] = experts
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
))
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
)
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
))
)
)
torch.testing.assert_close(actual_num_tokens,
golden_num_tokens,
atol=0,
rtol=0)
torch.testing.assert_close(actual_expert_ids,
golden_expert_ids,
atol=0,
rtol=0)
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
# For sorted_token_ids, verify block-level correctness rather than exact
# order Tokens within each expert's blocks can be in any order, but expert
@@ -219,16 +216,18 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
total_tokens = m * topk
assert actual_num_tokens.item() % block_size == 0, (
"num_tokens_post_pad should be divisible by block_size")
"num_tokens_post_pad should be divisible by block_size"
)
assert actual_num_tokens.item() >= total_tokens, (
"num_tokens_post_pad should be at least total_tokens")
"num_tokens_post_pad should be at least total_tokens"
)
valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
assert len(valid_tokens) == total_tokens, (
f"Should have exactly {total_tokens} valid tokens, "
f"got {len(valid_tokens)}")
assert (actual_expert_ids >= 0).all() and (
actual_expert_ids
< num_experts).all(), "expert_ids should contain valid expert indices"
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
)
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
"expert_ids should contain valid expert indices"
)
@pytest.mark.parametrize("m", [16, 32])
@@ -236,46 +235,37 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int,
block_size: int):
def test_moe_align_block_size_with_expert_map(
m: int, topk: int, num_experts: int, block_size: int
):
"""Test moe_align_block_size with expert mapping (EP scenario)"""
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
for i in range(m):
experts = torch.randperm(num_experts, device="cuda")[:topk]
topk_ids[i] = experts
expert_map = torch.full((num_experts, ),
-1,
device="cuda",
dtype=torch.int32)
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
local_experts = list(range(0, num_experts, 2))
for i, expert_id in enumerate(local_experts):
expert_map[expert_id] = i
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
))
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
)
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
))
)
)
torch.testing.assert_close(actual_num_tokens,
golden_num_tokens,
atol=0,
rtol=0)
torch.testing.assert_close(actual_expert_ids,
golden_expert_ids,
atol=0,
rtol=0)
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
_verify_expert_level_sorting(
actual_sorted_ids,
golden_sorted_ids,
@@ -290,26 +280,25 @@ def test_moe_align_block_size_deterministic():
m, topk, num_experts, block_size = 128, 2, 32, 64
torch.manual_seed(42)
topk_ids = torch.randint(0,
num_experts, (m, topk),
device="cuda",
dtype=torch.int32)
topk_ids = torch.randint(
0, num_experts, (m, topk), device="cuda", dtype=torch.int32
)
# expect the results to be reproducible
results = []
for _ in range(5):
sorted_ids, expert_ids, num_tokens = moe_align_block_size(
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts)
results.append(
(sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts
)
results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
for i in range(1, len(results)):
assert torch.equal(
results[0][0],
results[i][0]), ("sorted_ids should be deterministic")
assert torch.equal(
results[0][1],
results[i][1]), ("expert_ids should be deterministic")
assert torch.equal(
results[0][2],
results[i][2]), ("num_tokens should be deterministic")
assert torch.equal(results[0][0], results[i][0]), (
"sorted_ids should be deterministic"
)
assert torch.equal(results[0][1], results[i][1]), (
"expert_ids should be deterministic"
)
assert torch.equal(results[0][2], results[i][2]), (
"num_tokens should be deterministic"
)

View File

@@ -14,7 +14,10 @@ import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
moe_permute,
moe_permute_unpermute_supported,
moe_unpermute,
)
from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64, 256]
@@ -24,35 +27,34 @@ current_platform.seed_everything(0)
def torch_permute(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
# token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
# token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
if expert_map is not None:
is_local_expert = (expert_map[topk_ids] != -1)
not_local_expert = (expert_map[topk_ids] == -1)
topk_ids = is_local_expert * (
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
is_local_expert = expert_map[topk_ids] != -1
not_local_expert = expert_map[topk_ids] == -1
topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
topk_ids + n_expert
)
token_expert_indices = torch.arange(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk))
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
stable=True)
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
expert_first_token_offset = torch.zeros(n_local_expert + 1,
dtype=torch.int64,
device="cuda")
expert_first_token_offset = torch.zeros(
n_local_expert + 1, dtype=torch.int64, device="cuda"
)
idx = 0
for i in range(0, n_local_expert):
cnt = 0
@@ -64,116 +66,133 @@ def torch_permute(
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
topk, ...]
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda",
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[
expert_first_token_offset[-1]:] = n_token * topk
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
return [
permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
valid_row_idx
permuted_hidden_states,
expert_first_token_offset,
src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map,
m_indices,
valid_row_idx,
]
else:
permuted_row_size = (topk * n_token + n_expert *
(align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size
permuted_idx = torch.full((permuted_row_size, ),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device)
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
device="cuda",
dtype=hidden_states.dtype)
align_src_row_id2dst_row_id = torch.empty(n_token * topk,
device="cuda",
dtype=torch.int32)
align_expert_first_token_offset = torch.zeros_like(
expert_first_token_offset)
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
permuted_row_size = (
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
// align_block_size
* align_block_size
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
)
align_src_row_id2dst_row_id = torch.empty(
n_token * topk, device="cuda", dtype=torch.int32
)
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[
i] = align_expert_first_token_offset[
i - 1] + (n_token_in_expert + align_block_size -
1) // align_block_size * align_block_size
align_expert_first_token_offset[i] = (
align_expert_first_token_offset[i - 1]
+ (n_token_in_expert + align_block_size - 1)
// align_block_size
* align_block_size
)
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset:first_token_offset + n_token_in_expert]
first_token_offset : first_token_offset + n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states[align_first_token_offset:\
align_first_token_offset+n_token_in_expert,\
...] = hidden_states[\
dst_row_id2src_row_id_in_expert // topk,\
...]
permuted_idx[align_first_token_offset:\
align_first_token_offset+\
n_token_in_expert] = dst_row_id2src_row_id_in_expert
permuted_hidden_states[
align_first_token_offset : align_first_token_offset + n_token_in_expert,
...,
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
permuted_idx[
align_first_token_offset : align_first_token_offset + n_token_in_expert
] = dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i for i in range(align_first_token_offset,
align_first_token_offset + n_token_in_expert)
i
for i in range(
align_first_token_offset,
align_first_token_offset + n_token_in_expert,
)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if (eid >= n_local_expert):
if eid >= n_local_expert:
# check token not in local expert
align_src_row_id2dst_row_id[
i] = align_expert_first_token_offset[-1]
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[
i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\
src2dst_idx].reshape((n_token, topk))
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
(n_token, topk)
)
return [
permuted_hidden_states, align_expert_first_token_offset,
align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
permuted_hidden_states,
align_expert_first_token_offset,
align_src_row_id2dst_row_id,
permuted_idx,
m_indices,
valid_row_idx,
]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
valid_row_idx: torch.Tensor, topk: int,
n_expert: int) -> torch.Tensor:
def torch_unpermute(
permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
valid_row_idx: torch.Tensor,
topk: int,
n_expert: int,
) -> torch.Tensor:
# ignore invalid row
n_hidden = permuted_hidden_states.shape[1]
mask = torch.zeros(permuted_hidden_states.shape[0],
dtype=bool,
device="cuda")
mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda")
mask[valid_row_idx] = True
permuted_hidden_states[~mask] = 0
permuted_hidden_states = permuted_hidden_states[
src_row_id2dst_row_id_map.flatten(), ...]
src_row_id2dst_row_id_map.flatten(), ...
]
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
permuted_hidden_states.dtype)
output = (
(permuted_hidden_states * topk_weights.unsqueeze(2))
.sum(1)
.to(permuted_hidden_states.dtype)
)
return output
@@ -184,59 +203,76 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
n_expert: int, ep_size: int, dtype: torch.dtype,
align_block_size: Optional[int]):
def test_moe_permute_unpermute(
n_token: int,
n_hidden: int,
topk: int,
n_expert: int,
ep_size: int,
dtype: torch.dtype,
align_block_size: Optional[int],
):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
n_local_expert = n_expert
if (ep_size != 1):
n_local_expert, expert_map = determine_expert_map(
ep_size, ep_rank, n_expert)
if ep_size != 1:
n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
expert_map = expert_map.cuda()
start_expert = n_local_expert * ep_rank
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, False)
(gold_permuted_hidden_states, gold_expert_first_token_offset,
gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
valid_row_idx) = torch_permute(
hidden_states,
topk_ids,
# token_expert_indices,
topk,
n_expert,
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
hidden_states, gating_output, topk, False
)
(
gold_permuted_hidden_states,
gold_expert_first_token_offset,
gold_inv_permuted_idx,
gold_permuted_idx,
gold_m_indices,
valid_row_idx,
) = torch_permute(
hidden_states,
topk_ids,
# token_expert_indices,
topk,
n_expert,
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
(permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
m_indices) = moe_permute(hidden_states=hidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
(
permuted_hidden_states,
_,
expert_first_token_offset,
inv_permuted_idx,
m_indices,
) = moe_permute(
hidden_states=hidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
# check expert_first_token_offset
torch.testing.assert_close(gold_expert_first_token_offset,
expert_first_token_offset,
atol=0,
rtol=0)
torch.testing.assert_close(
gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
)
# check src_row_id2dst_row_id_map
torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
inv_permuted_idx,
atol=0,
rtol=0)
torch.testing.assert_close(
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
@@ -244,19 +280,28 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0,
rtol=0)
torch.testing.assert_close(
gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0,
rtol=0,
)
# add a random tensor to simulate group gemm
result0 = 0.5 * permuted_hidden_states + torch.randn_like(
permuted_hidden_states)
result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
result4 = torch.empty_like(hidden_states)
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
expert_first_token_offset)
moe_unpermute(
result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
)
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
token_expert_indices, inv_permuted_idx,
valid_row_idx, topk, n_local_expert)
gold4 = torch_unpermute(
result0,
topk_weights,
topk_ids,
token_expert_indices,
inv_permuted_idx,
valid_row_idx,
topk,
n_local_expert,
)
# check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)

View File

@@ -11,27 +11,39 @@ import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW4A4MXFP4)
QuarkLinearMethod,
QuarkW4A4MXFP4,
)
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkW4A4MXFp4MoEMethod)
QuarkW4A4MXFp4MoEMethod,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
TRTLLM_GEN_MXFP4_AVAILABLE = (
current_platform.is_cuda() and current_platform.is_device_capability(100)
)
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer())
HOPPER_MXFP4_BF16_AVAILABLE = (
current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer()
)
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
from flashinfer import (
fp4_quantize,
mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
trtllm_fp4_block_scale_moe,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
@@ -48,21 +60,25 @@ def enable_pickle(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize('model_case', [
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize(
"model_case",
[
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
if torch.cuda.device_count() < model_case.tp:
pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}")
pytest.skip(
f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}"
)
with vllm_runner(model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy") as llm:
with vllm_runner(
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
) as llm:
def check_model(model):
layer = model.model.layers[0]
@@ -72,21 +88,16 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
assert isinstance(layer.mlp.experts.quant_method,
QuarkW4A4MXFp4MoEMethod)
assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
llm.apply_model(check_model)
output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20)
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
assert output
def swiglu(x,
alpha: float = 1.702,
beta: float = 1.0,
limit: Optional[float] = None):
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
@@ -96,24 +107,19 @@ def swiglu(x,
return out_glu * (x_linear + beta)
fp4_lookup_table = [
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
def mxfp4_dequantize(x, scale):
assert x.dtype == torch.uint8
x = x.view(torch.uint8).to(torch.int32)
x_unpacked = torch.zeros(*x.shape[:-1],
x.shape[-1] * 2,
dtype=torch.int32,
device=x.device)
x_unpacked = torch.zeros(
*x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
)
x_unpacked[..., 0::2].copy_(x & 0xF)
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
x_float = torch.zeros(x_unpacked.shape,
dtype=torch.float32,
device=x.device)
x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
for i, val in enumerate(fp4_lookup_table):
x_float[x_unpacked == i] = val
@@ -162,9 +168,10 @@ def reference_moe(
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if act_type == 'mxfp8':
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
is_sf_swizzled_layout=False)
if act_type == "mxfp8":
t_quantized, t_scale = mxfp8_quantize(
t.to(torch.bfloat16), is_sf_swizzled_layout=False
)
t = mxfp8_dequantize(t_quantized, t_scale)
# MLP #2
mlp2_weight = w2[expert_indices, ...]
@@ -221,37 +228,53 @@ def tg_mxfp4_moe(
transpose_optimized: bool = False,
) -> torch.Tensor:
sf_block_size = 32
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2)
assert (w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2)
assert (w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2)
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size)
assert (
w13_weight.dim() == 3
and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2
)
assert (
w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size
)
assert (
w2_weight.dim() == 3
and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2
)
assert (
w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
)
assert (
w13_bias.dim() == 2
and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2
)
assert (
w2_bias.dim() == 2
and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size
)
# Swap w1 and w3 as the definition of
# swiglu is different in the trtllm-gen
w13_weight_scale_ = w13_weight_scale.clone()
w13_weight_ = w13_weight.clone()
w13_bias_ = w13_bias.clone()
w13_weight[:, :intermediate_size, :].copy_(
w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(
w13_weight_[:, :intermediate_size, :])
w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
w13_weight_scale[:, :intermediate_size, :].copy_(
w13_weight_scale_[:, intermediate_size:, :])
w13_weight_scale_[:, intermediate_size:, :]
)
w13_weight_scale[:, intermediate_size:, :].copy_(
w13_weight_scale_[:, :intermediate_size, :])
w13_weight_scale_[:, :intermediate_size, :]
)
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
@@ -261,18 +284,23 @@ def tg_mxfp4_moe(
w13_bias_interleaved = []
for i in range(num_experts):
w13_weight_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
)
w13_weight_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
)
w13_bias_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
1)))
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
)
w13_weight = torch.stack(w13_weight_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2)
num_experts, 2 * intermediate_size, hidden_size // 2
)
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 32)
num_experts, 2 * intermediate_size, hidden_size // 32
)
w13_bias = torch.stack(w13_bias_interleaved).reshape(
num_experts, 2 * intermediate_size)
num_experts, 2 * intermediate_size
)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled = []
@@ -291,9 +319,11 @@ def tg_mxfp4_moe(
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
gemm1_weights_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
@@ -302,26 +332,35 @@ def tg_mxfp4_moe(
num_elts_per_sf=16,
)
gemm1_scales_shuffled.append(
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
.contiguous()
)
)
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
gemm2_weights_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
@@ -330,48 +369,65 @@ def tg_mxfp4_moe(
num_elts_per_sf=16,
)
gemm2_scales_shuffled.append(
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
.contiguous()
)
)
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
else:
for i in range(num_experts):
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_sf_a(
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_sf_a(
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
)
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
)
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
num_experts, 2 * intermediate_size,
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
w13_weight_scale = (
torch.stack(gemm1_scales_shuffled)
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
w2_weight = torch.stack(gemm2_weights_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
num_experts, hidden_size,
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
w2_weight_scale = (
torch.stack(gemm2_scales_shuffled)
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
tg_result = trtllm_fp4_block_scale_moe(
@@ -401,7 +457,8 @@ def tg_mxfp4_moe(
routed_scaling_factor=None,
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
routing_method_type=1, # renormalize
do_finalize=True)[0]
do_finalize=True,
)[0]
return tg_result
@@ -424,20 +481,21 @@ def check_accuracy(a, b, atol, rtol, percent):
if mismatch_percent > 1 - percent:
raise Exception(
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
f"(threshold: {1-percent:.4f})")
f"(threshold: {1 - percent:.4f})"
)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP4_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test")
reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
topk: int,
num_experts: int,
@@ -452,45 +510,52 @@ def test_trtllm_gen_mxfp4_fused_moe(
):
seed = 42
torch.manual_seed(seed)
hidden_states = torch.randn(num_tokens,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16)
w13 = (torch.randn(num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16))
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16))
bias13 = torch.randn(num_experts, intermediate_size * 2,
device="cuda:0") * 10
hidden_states = torch.randn(
num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
)
w13 = torch.randn(
num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16,
)
w2 = torch.randn(
num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16,
)
bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
router_logits = torch.rand(num_tokens, num_experts,
dtype=torch.float32).cuda()
router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()
w13, w13_scale = fp4_quantize(w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w13, w13_scale = fp4_quantize(
w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False,
)
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
num_experts, intermediate_size * 2, hidden_size // 32)
w2, w2_scale = fp4_quantize(w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
num_experts, intermediate_size * 2, hidden_size // 32
)
w2, w2_scale = fp4_quantize(
w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False,
)
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 32)
if act_type == 'mxfp8':
num_experts, hidden_size, intermediate_size // 32
)
if act_type == "mxfp8":
hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False)
hidden_states_scale = hidden_states_scale.view(
torch.float8_e4m3fn).reshape(-1)
hidden_states, is_sf_swizzled_layout=False
)
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
else:
hidden_states_scale = None
@@ -500,9 +565,10 @@ def test_trtllm_gen_mxfp4_fused_moe(
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
bias13_ref = bias13
bias2_ref = bias2
if act_type == 'mxfp8':
hidden_states_ref = mxfp8_dequantize(
hidden_states, hidden_states_scale).to(torch.float32)
if act_type == "mxfp8":
hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
torch.float32
)
else:
hidden_states_ref = hidden_states.to(torch.float32)
# Process tokens in chunks of 32 to reduce memory usage
@@ -529,29 +595,31 @@ def test_trtllm_gen_mxfp4_fused_moe(
# trtllm-gen result
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
limit = torch.full((num_experts,), limit, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit,
transpose_optimized=transpose_optimized)
beta = torch.full((num_experts,), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(
router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit,
transpose_optimized=transpose_optimized,
)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
@@ -573,8 +641,7 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not HOPPER_MXFP4_BF16_AVAILABLE,
reason="nvidia gpu sm90 and flashinfer are required for this test",
@@ -593,52 +660,73 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
hidden_states = torch.randn(
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
)
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
w13_q = torch.randint(
0,
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
256,
(num_experts, 2 * intermediate_size, hidden_size // 2),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
w13_scale = torch.randint(
118,
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
123,
(num_experts, 2 * intermediate_size, hidden_size // 32),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
w2_q = torch.randint(0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8)
w2_q = torch.randint(
0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8,
)
w2_scale = torch.randint(
118,
123, (num_experts, hidden_size, intermediate_size // 32),
123,
(num_experts, hidden_size, intermediate_size // 32),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
bias13 = (
torch.randn(
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
)
* 10
)
bias2 = (
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
)
router_logits = torch.rand(
num_tokens, num_experts, dtype=torch.float32, device=device
)
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
num_experts, 2 * intermediate_size, hidden_size)
num_experts, 2 * intermediate_size, hidden_size
)
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
num_experts, hidden_size, intermediate_size)
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
num_experts, hidden_size, intermediate_size
)
ref = reference_moe(
router_logits.to(torch.float32),
topk,
num_experts,
hidden_states.to(torch.float32),
w13_ref,
bias13.to(torch.float32),
w2_ref,
bias2.to(torch.float32),
alpha,
beta,
limit,
"bf16",
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -654,23 +742,24 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float32
)
token_final_scales, token_selected_experts = torch.topk(
routing_weights, topk, dim=-1
)
token_final_scales = token_final_scales / token_final_scales.sum(
dim=-1, keepdim=True
)
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
beta = torch.full((num_experts,), beta, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
limit = torch.full((num_experts,), limit, device=hidden_states.device)
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
@@ -680,8 +769,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
fc2_expert_weights=w2_q,
output_dtype=torch.bfloat16,
output=out,
quant_scales=[w13_s_inter.to(torch.uint8),
w2_s_inter.to(torch.uint8)],
quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha,
@@ -702,11 +790,13 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not (current_platform.is_cuda()
and current_platform.is_device_capability(100) and has_flashinfer()),
not (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and has_flashinfer()
),
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
@@ -723,32 +813,43 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
hidden_states = torch.randn(
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
)
# Float weights in w13 format [w1; w3]
w13 = (torch.randn(num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16) / 10)
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16) / 10)
w13 = (
torch.randn(
num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16,
)
/ 10
)
w2 = (
torch.randn(
num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16,
)
/ 10
)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
bias13 = (
torch.randn(
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
)
* 10
)
bias2 = (
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
)
router_logits = torch.rand(
num_tokens, num_experts, dtype=torch.float32, device=device
)
# Quantize weights to MXFP4 per expert (SM100 path)
from flashinfer import mxfp4_quantize
@@ -761,36 +862,56 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
sfs.append(sf)
return torch.stack(qs), torch.stack(sfs)
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor):
def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
from flashinfer import mxfp4_dequantize
return torch.stack([
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
])
return torch.stack(
[
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
]
)
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
# Reference result using dequantized tensors and reference_moe
w13_ref = dequant_mxfp4_batches(
w13_q.view(torch.uint8),
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, 2 * intermediate_size, hidden_size).to(device)
w2_ref = dequant_mxfp4_batches(
w2_q.view(torch.uint8),
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, hidden_size, intermediate_size).to(device)
w13_ref = (
dequant_mxfp4_batches(
w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
)
.to(torch.float32)
.reshape(num_experts, 2 * intermediate_size, hidden_size)
.to(device)
)
w2_ref = (
dequant_mxfp4_batches(
w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
)
.to(torch.float32)
.reshape(num_experts, hidden_size, intermediate_size)
.to(device)
)
# Quantize activations for SM100 path and dequantize for reference
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
ref = reference_moe(
router_logits.to(torch.float32),
topk,
num_experts,
hidden_states.to(torch.float32),
w13_ref,
bias13.to(torch.float32),
w2_ref,
bias2.to(torch.float32),
alpha,
beta,
limit,
"mxfp8",
)
# Prepare inputs for FlashInfer CUTLASS fused MoE
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -807,31 +928,28 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
# Build routing for kernel
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float32
)
token_final_scales, token_selected_experts = torch.topk(
routing_weights, topk, dim=-1
)
token_final_scales = token_final_scales / token_final_scales.sum(
dim=-1, keepdim=True
)
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha_t = torch.full((num_experts, ),
alpha,
device=hidden_states.device)
alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
else:
alpha_t = None
if beta is not None:
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
else:
beta_t = None
if limit is not None:
limit_t = torch.full((num_experts, ),
limit,
device=hidden_states.device)
limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
else:
limit_t = None

View File

@@ -4,9 +4,11 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
@@ -16,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
"Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
)
MNK_FACTORS = [
(2, 1024, 1024),
@@ -38,36 +41,34 @@ MNK_FACTORS = [
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
quant_blocksize = 16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_out_ch_quant=False,
)
(_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_out_ch_quant=False,
)
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
@@ -97,40 +98,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
cutlass_output,
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
if __name__ == "__main__":

View File

@@ -9,13 +9,10 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -24,9 +21,13 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_finalize,
nvshmem_get_unique_id,
nvshmem_init,
)
has_pplx = True
except ImportError:
has_pplx = False
@@ -50,12 +51,12 @@ def chunk_by_rank(t, r, w):
chunk = rank_chunk(num, r, w)
rem = num % w
if rem == 0 or r < rem:
return t[(r * chunk):(r + 1) * chunk].contiguous()
return t[(r * chunk) : (r + 1) * chunk].contiguous()
else:
long_chunks = (num // w + 1) * rem
short_chunks = (r - rem) * chunk
start = long_chunks + short_chunks
return t[start:start + chunk].contiguous()
return t[start : start + chunk].contiguous()
def pplx_cutlass_moe(
@@ -75,7 +76,9 @@ def pplx_cutlass_moe(
group_name: Optional[str],
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize,
)
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
@@ -126,35 +129,40 @@ def pplx_cutlass_moe(
ata,
max_num_tokens=max_num_tokens,
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
num_dispatchers=num_dispatchers,
)
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides1 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
ab_strides2 = torch.full(
(num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides1 = torch.full(
(num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides2 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
num_local_experts,
num_dispatchers,
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
if per_act_token
else a1_scale[rank],
),
)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
@@ -162,10 +170,10 @@ def pplx_cutlass_moe(
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
world_size).to(torch.uint32).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
chunk_topk_ids = (
chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
)
out = fused_cutlass_experts(
a_chunk,
@@ -174,7 +182,7 @@ def pplx_cutlass_moe(
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
expert_map=None, # TODO
)
torch.cuda.synchronize()
@@ -210,35 +218,48 @@ def _pplx_moe(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
with set_current_vllm_config(vllm_config):
torch_output = torch_experts(a_full, w1_full, w2_full,
topk_weights, topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch, group_name)
torch_output = torch_experts(
a_full, w1_full, w2_full, topk_weights, topk_ids
)
pplx_output = pplx_cutlass_moe(
pgi,
dp_size,
a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a1_scale,
out_dtype,
per_act_token,
per_out_ch,
group_name,
)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
pplx_output.device
)
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output,
torch_output,
atol=0.05,
rtol=0)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
finally:
if use_internode:
nvshmem_finalize()
@@ -251,13 +272,15 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
m: int,
@@ -273,7 +296,6 @@ def test_cutlass_moe_pplx(
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
@@ -283,22 +305,18 @@ def test_cutlass_moe_pplx(
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_q = torch.empty((e, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w1[expert], use_per_token_if_dynamic=per_out_ch
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w2[expert], use_per_token_if_dynamic=per_out_ch
)
w1_d = torch.empty_like(w1)
w2_d = torch.empty_like(w2)
@@ -307,19 +325,35 @@ def test_cutlass_moe_pplx(
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
world_size, dp_size = world_dp_size
a_scale1 = torch.randn(
(m if per_act_token else 1, 1), device="cuda",
dtype=torch.float32) / 10.0
a_scale1 = (
torch.randn(
(m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
)
/ 10.0
)
if not per_act_token:
a_scale1 = a_scale1.repeat(world_size, 1)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
use_internode)
parallel_launch(
world_size,
_pplx_moe,
dp_size,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a_scale1,
dtype,
a,
w1_d,
w2_d,
per_act_token,
per_out_ch,
use_internode,
)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import copy
import itertools
import textwrap
@@ -15,29 +16,34 @@ import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_finalize,
nvshmem_get_unique_id,
nvshmem_init,
)
has_pplx = True
except ImportError:
has_pplx = False
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
make_shared_experts,
make_test_weights,
naive_batched_moe,
)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceDelegate,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -59,7 +65,7 @@ BATCHED_MOE_MNK_FACTORS = [
PPLX_COMBOS = [
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
# (1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(4, 128, 128),
@@ -91,17 +97,16 @@ def torch_prepare(
num_tokens, hidden_dim = a.shape
topk = topk_ids.shape[1]
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
assert tokens_per_expert.numel() == num_experts
if max_num_tokens is None:
max_num_tokens = int(tokens_per_expert.max().item())
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
dtype=a.dtype,
device=a.device)
b_a = torch.zeros(
(num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
)
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
@@ -109,28 +114,29 @@ def torch_prepare(
for j in range(topk):
expert_id = topk_ids[token, j]
idx = token_counts[expert_id]
b_a[expert_id, idx:idx + 1, :] = a[token, :]
b_a[expert_id, idx : idx + 1, :] = a[token, :]
token_counts[expert_id] = token_counts[expert_id] + 1
return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
def torch_finalize(
b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
) -> torch.Tensor:
num_tokens = topk_ids.shape[0]
num_experts = b_out.shape[0]
K = b_out.shape[-1]
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
expert_counts = torch.zeros(num_experts,
dtype=torch.int,
device=b_out.device)
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
for token in range(num_tokens):
expert_ids = topk_ids[token]
for i in range(expert_ids.numel()):
expert_id = expert_ids[i]
idx = expert_counts[expert_id]
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
1, :] * topk_weight[token, i]
out[token, :] = (
out[token, :]
+ b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i]
)
expert_counts[expert_id] = expert_counts[expert_id] + 1
return out
@@ -149,17 +155,18 @@ def torch_batched_moe(
num_tokens, topk = topk_ids.shape
_, max_num_tokens, K = b_a.shape
assert num_experts == b_a.shape[0] and w2.shape[1] == K
out = torch.zeros((num_experts, max_num_tokens, K),
dtype=b_a.dtype,
device=b_a.device)
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
dtype=b_a.dtype,
device=b_a.device)
out = torch.zeros(
(num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
)
tmp = torch.empty(
(max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
)
for expert in range(num_experts):
num = tokens_per_expert[expert]
if num > 0:
torch.ops._C.silu_and_mul(
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
)
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
return torch_finalize(out, topk_weight, topk_ids)
@@ -186,20 +193,16 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_experts(a, w1, w2, topk_weight,
topk_ids) # only for baseline
baseline_output = torch_experts(
a, w1, w2, topk_weight, topk_ids
) # only for baseline
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = naive_batched_moe(
a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this
a, w1, w2, topk_weight, topk_ids
) # pick torch_experts or this
torch.testing.assert_close(baseline_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
def create_pplx_prepare_finalize(
@@ -217,7 +220,9 @@ def create_pplx_prepare_finalize(
group_name: Optional[str],
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
num_local_experts = rank_chunk(num_experts, 0, world_size)
@@ -266,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int:
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
return t[(r * chunk) : (r + 1) * chunk]
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
w: int) -> Optional[torch.Tensor]:
def maybe_chunk_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
if t is not None:
return chunk_by_rank(t, r, w)
else:
return t
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
w: int) -> Optional[torch.Tensor]:
def chunk_scales_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
if t is not None and t.numel() > 1:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
return t[(r * chunk) : (r + 1) * chunk]
else:
return t
def chunk_scales(t: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
def chunk_scales(
t: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
if t is not None and t.numel() > 1:
return t[start:end]
else:
@@ -350,8 +358,7 @@ def pplx_prepare_finalize(
device=device,
)
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
@@ -375,8 +382,7 @@ def pplx_prepare_finalize(
),
)
b_a = dummy_work(
dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
prepare_finalize.finalize(
out,
@@ -410,15 +416,17 @@ def _pplx_prepare_finalize(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
@@ -426,22 +434,28 @@ def _pplx_prepare_finalize(
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
torch_output = (a_rep.view(m, topk, k) *
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
dim=1)
torch_output = (
a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
).sum(dim=1)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
topk_ids, num_experts, quant_dtype,
block_shape, per_act_token_quant,
group_name)
pplx_output = pplx_prepare_finalize(
pgi,
dp_size,
a,
topk_weight,
topk_ids,
num_experts,
quant_dtype,
block_shape,
per_act_token_quant,
group_name,
)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pgi.device)
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
pgi.device
)
torch.testing.assert_close(pplx_output,
torch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
finally:
if use_internode:
nvshmem_finalize()
@@ -491,9 +505,19 @@ def test_pplx_prepare_finalize_slow(
a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
score = torch.randn((m, e), device=device, dtype=act_dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e, quant_dtype, block_shape, per_act_token_quant,
use_internode)
parallel_launch(
world_size,
_pplx_prepare_finalize,
dp_size,
a,
score,
topk,
e,
quant_dtype,
block_shape,
per_act_token_quant,
use_internode,
)
def pplx_moe(
@@ -517,7 +541,6 @@ def pplx_moe(
use_cudagraphs: bool = True,
shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
topk = topk_ids.shape[1]
@@ -579,21 +602,23 @@ def pplx_moe(
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
_fused_experts = torch.compile(
fused_experts, backend="inductor", fullgraph=True
)
torch._dynamo.mark_dynamic(a_chunk, 0)
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
else:
_fused_experts = fused_experts
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
out = _fused_experts(
a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
)
if use_cudagraphs:
if isinstance(out, tuple):
@@ -604,12 +629,14 @@ def pplx_moe(
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
out = _fused_experts(
a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
)
torch.cuda.synchronize()
graph.replay()
@@ -640,15 +667,17 @@ def _pplx_moe(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
m, k = a.shape
@@ -666,8 +695,7 @@ def _pplx_moe(
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
@@ -742,31 +770,27 @@ def _pplx_moe(
if shared_output is not None:
assert pplx_shared_output is not None
chunked_shared_output = chunk_by_rank(
shared_output, pgi.rank,
pgi.world_size).to(pplx_shared_output.device)
shared_output, pgi.rank, pgi.world_size
).to(pplx_shared_output.device)
else:
chunked_shared_output = None
chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
batched_output, pgi.rank, pgi.world_size
).to(pplx_output.device)
torch.testing.assert_close(batched_output,
torch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
torch.testing.assert_close(pplx_output,
chunked_batch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(
pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
)
if shared_experts is not None:
assert chunked_shared_output is not None
assert pplx_shared_output is not None
torch.testing.assert_close(pplx_shared_output,
chunked_shared_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(
pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
)
finally:
if use_internode:
@@ -823,15 +847,33 @@ def test_pplx_moe_slow(
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode)
parallel_launch(
world_size,
_pplx_moe,
dp_size,
a,
w1,
w2,
score,
topk,
e,
w1_s,
w2_s,
quant_dtype,
per_act_token_quant,
block_shape,
use_internode,
)
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
use_shared_experts: bool, make_weights: bool,
test_fn: Callable):
def _pplx_test_loop(
pgi: ProcessGroupInfo,
dp_size: int,
use_internode: bool,
use_shared_experts: bool,
make_weights: bool,
test_fn: Callable,
):
def format_result(msg, ex=None):
if ex is not None:
x = str(ex)
@@ -850,12 +892,12 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
new_vllm_config = copy.deepcopy(vllm_config)
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
new_vllm_config.parallel_config.enable_expert_parallel = True
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
pgi.local_rank)
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)
current_platform.seed_everything(7)
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
[False, True], [None, [128, 128]])
combos = itertools.product(
PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
)
exceptions = []
count = 0
for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
@@ -873,13 +915,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}, use_internode={use_internode}, "
f"use_shared_experts={use_shared_experts}")
f"use_shared_experts={use_shared_experts}"
)
if not use_fp8_w8a8 and (per_act_token_quant
or block_shape is not None):
print(
f"{test_desc} - Skip quantization test for non-quantized type."
)
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
print(f"{test_desc} - Skip quantization test for non-quantized type.")
continue
if per_act_token_quant and block_shape is not None:
@@ -934,10 +974,10 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
f"rank={pgi.rank}."
)
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@@ -950,8 +990,15 @@ def test_pplx_prepare_finalize(
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
use_internode, False, False, _pplx_prepare_finalize)
parallel_launch(
world_size * dp_size,
_pplx_test_loop,
dp_size,
use_internode,
False,
False,
_pplx_prepare_finalize,
)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@@ -966,5 +1013,12 @@ def test_pplx_moe(
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
use_shared_experts, True, _pplx_moe)
parallel_launch(
world_size,
_pplx_test_loop,
dp_size,
use_internode,
use_shared_experts,
True,
_pplx_moe,
)

View File

@@ -24,13 +24,14 @@ aiter_available = importlib.util.find_spec("aiter") is not None
pytestmark = pytest.mark.skipif(
not (current_platform.is_rocm() and aiter_available),
reason="AITER ops are only available on ROCm with aiter package installed")
reason="AITER ops are only available on ROCm with aiter package installed",
)
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
@@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
def test_rocm_aiter_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")
# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
@@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
renormalize = True
scale_factor = 1.0
gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
e_score_correction_bias = torch.randn((expert, ),
dtype=torch.bfloat16,
device="cuda")
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
e_score_correction_bias = torch.randn(
(expert,), dtype=torch.bfloat16, device="cuda"
)
device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
# Define a function that uses the op
def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights, topk_ids):
def biased_grouped_topk_fn(
gating_output, e_score_correction_bias, topk_weights, topk_ids
):
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output, e_score_correction_bias, topk_weights, topk_ids,
num_expert_group, topk_group, renormalize, scale_factor)
gating_output,
e_score_correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
scale_factor,
)
# Verify the op's fake implementation
torch.library.opcheck(
@@ -84,51 +89,49 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"routed_scaling_factor": scale_factor
"routed_scaling_factor": scale_factor,
},
test_utils=("test_faketensor"))
test_utils=("test_faketensor"),
)
# Compile the function with appropriate settings
compiled_fn = torch.compile(biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_fn = torch.compile(
biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_original = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_compiled = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights_original, topk_ids_original)
compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
topk_ids_compiled)
biased_grouped_topk_fn(
gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
)
compiled_fn(
gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
)
assert torch.allclose(topk_ids_original, topk_ids_compiled)
@@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
scoring_func = "softmax"
scale_factor = 1.0
gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
# Define a function that uses the op
def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
return torch.ops.vllm.rocm_aiter_grouped_topk(
gating_output, topk_weights, topk_ids, num_expert_group,
topk_group, renormalize, scoring_func, scale_factor)
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
scoring_func,
scale_factor,
)
# Verify the op's fake implementation
torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
(gating_output, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"scoring_func": scoring_func,
"routed_scaling_factor": scale_factor
},
test_utils=("test_faketensor"))
torch.library.opcheck(
torch.ops.vllm.rocm_aiter_grouped_topk,
(gating_output, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"scoring_func": scoring_func,
"routed_scaling_factor": scale_factor,
},
test_utils=("test_faketensor"),
)
# Compile the function with appropriate settings
compiled_fn = torch.compile(grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_fn = torch.compile(
grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_original = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_compiled = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
scoring_func)
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
scoring_func)
grouped_topk_fn(
gating_output, topk_weights_original, topk_ids_original, scoring_func
)
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
)
assert torch.allclose(topk_ids_original, topk_ids_compiled)

View File

@@ -5,7 +5,8 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda)
silu_mul_fp8_quant_deep_gemm_cuda,
)
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -34,7 +35,6 @@ CASES = [
(256, 16, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype),
# Only add a few fnuz tests to help with long CI times.
(8, 512, 7168, torch.float8_e4m3fnuz),
(8, 1024, 7168, torch.float8_e4m3fnuz),
@@ -52,15 +52,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
tokens_per_expert = torch.randint(
low=T // 2,
high=T,
size=(E, ),
size=(E,),
dtype=torch.int32,
device="cuda",
)
# Run the Triton kernel
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y,
tokens_per_expert,
group_size=group_size)
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
y, tokens_per_expert, group_size=group_size
)
torch.cuda.synchronize()
fp8_info = torch.finfo(fp8_dtype)
@@ -75,9 +75,9 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty((T, cdiv(H, group_size)),
dtype=torch.float32,
device="cuda")
ref_s = torch.empty(
(T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
)
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
for t in range(nt):
@@ -87,14 +87,17 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
# process full groups
n_full_groups = H // group_size
if n_full_groups > 0:
data_grp = data[:n_full_groups * group_size].view(
n_full_groups, group_size)
data_grp = data[: n_full_groups * group_size].view(
n_full_groups, group_size
)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
scaled = data[:n_full_groups *
group_size] / scale.repeat_interleave(group_size)
ref_q_row[:n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max).to(fp8_dtype)
scaled = data[: n_full_groups * group_size] / scale.repeat_interleave(
group_size
)
ref_q_row[: n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max
).to(fp8_dtype)
ref_s[t, :n_full_groups] = scale
# process remainder group

View File

@@ -11,13 +11,11 @@ from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -31,14 +29,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(
), "B must be a 2D contiguous tensor"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K, )
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
@@ -88,17 +85,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = ops.scaled_fp8_quant(
act_out, use_per_token_if_dynamic=True)
act_out, use_per_token_if_dynamic=True
)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
output_dtype=a.dtype)
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -116,8 +113,10 @@ TOP_KS = [2, 6]
SEEDS = [0]
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
torch.manual_seed(seed)
@@ -133,12 +132,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
# Generate int8 weights
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min,
max=fp8_max).to(torch.float8_e4m3fn)
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min,
max=fp8_max).to(torch.float8_e4m3fn)
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# Generate scale for each column (per-column quantization)
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
@@ -163,7 +160,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.05

View File

@@ -6,17 +6,17 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
BatchedPrepareAndFinalize,
BatchedTritonExperts,
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import round_up
from vllm.utils.deep_gemm import per_block_cast_to_fp8
@@ -45,12 +45,7 @@ def triton_moe(
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
quant_config=quant_config)
return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
def batched_moe(
@@ -80,10 +75,9 @@ def batched_moe(
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
),
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
@@ -121,10 +115,9 @@ def naive_batched_moe(
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
),
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
@@ -135,8 +128,9 @@ def naive_batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
def chunk_scales(
scales: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
@@ -159,13 +153,15 @@ def make_quantized_test_activations(
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
"only fp8/int8 supported"
)
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a[e], None, quant_dtype, per_act_token_quant, block_shape
)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
@@ -181,8 +177,11 @@ def moe_quantize_weights(
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
assert (
quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8
or quant_dtype == "nvfp4"
), "only fp8/int8/nvfp4 supported"
w_gs = None
@@ -199,10 +198,12 @@ def moe_quantize_weights(
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
w, w_s, use_per_token_if_dynamic=per_token_quant
)
elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
w, w_s, use_per_token_if_dynamic=per_token_quant
)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
@@ -222,8 +223,7 @@ def make_test_weight(
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
@@ -233,7 +233,8 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
@@ -264,26 +265,25 @@ def make_test_weights(
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
) -> tuple[
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_out_ch_quant),
make_test_weight(
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
)
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
@@ -313,27 +313,31 @@ def make_test_quant_config(
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a1_scale = a1_gscale
a2_scale = a2_gscale
else:
a1_scale = None
a2_scale = None
return w1, w2, FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
return (
w1,
w2,
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
),
)
@@ -348,21 +352,23 @@ def fused_moe(
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config)
topk_weights, topk_ids, _ = fused_topk(
hidden_states, score.float(), topk, renormalize
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config,
)
# CustomOp?
class BaselineMM(torch.nn.Module):
def __init__(
self,
b: torch.Tensor,
@@ -372,15 +378,11 @@ class BaselineMM(torch.nn.Module):
self.b = b.to(dtype=torch.float32)
self.out_dtype = out_dtype
def forward(
self,
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return torch.mm(a.to(dtype=torch.float32),
self.b).to(self.out_dtype), None
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
class TestMLP(torch.nn.Module):
def __init__(
self,
w1: torch.Tensor,
@@ -410,7 +412,6 @@ def make_naive_shared_experts(
class RealMLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
@@ -425,37 +426,48 @@ class RealMLP(torch.nn.Module):
w2_s: Optional[torch.Tensor] = None,
) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, RowParallelLinear)
MergedColumnParallelLinear,
RowParallelLinear,
)
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
)
self.gate_up_proj.register_parameter(
"weight", torch.nn.Parameter(w1, requires_grad=False))
"weight", torch.nn.Parameter(w1, requires_grad=False)
)
self.gate_up_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
)
self.gate_up_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
"input_scale", None
) # torch.nn.Parameter(None, requires_grad=False))
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
self.down_proj.register_parameter(
"weight", torch.nn.Parameter(w2, requires_grad=False))
"weight", torch.nn.Parameter(w2, requires_grad=False)
)
self.down_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
)
self.down_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
"input_scale", None
) # torch.nn.Parameter(None, requires_grad=False))
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
@@ -496,13 +508,6 @@ def make_shared_experts(
w2_s = None
quant_config = None
return RealMLP(K,
N,
w1,
w2,
"silu",
quant_config,
w1_s=w1_s,
w2_s=w2_s)
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)

View File

@@ -5,8 +5,7 @@ from typing import Optional, Union
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -17,25 +16,31 @@ FP8_DTYPE = current_platform.fp8_dtype()
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> tuple[torch.tensor, torch.tensor]:
def ref_dynamic_per_token_quant(
x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
) -> tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.min
qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
qtype_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.max
)
qtype_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.min
)
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@@ -56,15 +61,13 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
else:
assert quant_dtype == FP8_DTYPE
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
return torch_out, scales
@@ -72,16 +75,20 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> tuple[torch.tensor, torch.tensor]:
def ref_dynamic_per_tensor_fp8_quant(
x: torch.tensor,
) -> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.min
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)
@@ -92,9 +99,12 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return ref_out, ref_scale.view((1, ))
ref_out = (
(as_float32_tensor(x) * ref_iscale)
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1,))
def native_w8a8_block_matmul(
@@ -126,7 +136,7 @@ def native_w8a8_block_matmul(
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
@@ -137,19 +147,19 @@ def native_w8a8_block_matmul(
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
@@ -163,14 +173,14 @@ def native_w8a8_block_matmul(
return C
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
"be divisible by `group_size`")
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` must be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
@@ -178,28 +188,25 @@ def native_per_token_group_quant_fp8(x,
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` must be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
@@ -208,13 +215,13 @@ def native_per_token_group_quant_int8(x,
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = (
(x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype)
) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
@@ -229,9 +236,9 @@ def per_block_cast_to_int8(
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded = torch.zeros(
(round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
@@ -269,8 +276,9 @@ def batched_dequant(
assert t.shape[0] == scale.shape[0]
out = torch.empty_like(t, dtype=out_dtype)
for e in range(t.shape[0]):
out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant,
out_dtype)
out[e] = dequant(
t[e], scale[e], block_shape, per_act_token_quant, out_dtype
)
return out
return t.to(out_dtype)
@@ -294,15 +302,17 @@ def native_batched_masked_quant_matmul(
num_tokens = num_expert_tokens_cpu[e]
if A.dtype.itemsize == 1 and block_shape is not None:
assert A_scale is not None and B_scale is not None
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
block_shape, C.dtype)
tmp = native_w8a8_block_matmul(
A[e], B[e], A_scale[e], B_scale[e], block_shape, C.dtype
)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
elif A.dtype.itemsize == 1 and block_shape is None:
assert A_scale is not None and B_scale is not None
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
C[e, :num_tokens, :] = (
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(
C.dtype
)
else:
assert A_scale is None
assert B_scale is None

View File

@@ -8,8 +8,9 @@ from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
@@ -22,12 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
@@ -69,7 +67,8 @@ def break_fp4_bytes(a, dtype):
def quant_nvfp4_tensor(a: torch.Tensor):
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(a).max().to(torch.float32))
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
torch.float32
)
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale

View File

@@ -6,24 +6,25 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
ALLSPARK_AMPERE_K_ALIGN,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
def is_gptq_allspark_supported(min_capability: int,
max_capability: int) -> bool:
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
if not current_platform.is_cuda():
return False
capability = current_platform.get_device_capability()
assert capability is not None
return capability.to_int() >= min_capability \
and capability.to_int() <= max_capability
return (
capability.to_int() >= min_capability and capability.to_int() <= max_capability
)
MNK_FACTORS = [
@@ -43,7 +44,8 @@ HAS_ZP_OPTS = [False, True]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
torch.abs(output_ref)
)
def rand_data(shape, dtype=torch.float16):
@@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.skipif(
not is_gptq_allspark_supported(80, 89),
reason="AllSpark Ampere kernel is not supported on this GPU type.")
reason="AllSpark Ampere kernel is not supported on this GPU type.",
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
weight = rand_data((k, n), dtype=dtype)
# Quantize (and apply act_order if provided)
w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
group_size, has_zp)
w_ref, qw, s, zp = quantize_weights(
weight, scalar_types.uint8b128, group_size, has_zp
)
qw = qw.to(torch.uint8)
if has_zp:
@@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
n_32align = (n + 32 - 1) // 32 * 32
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
qw, s, zp, has_zp)
opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
n_32align))
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
opcheck(
torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align),
)
opcheck(torch.ops._C.allspark_w8a16_gemm,
(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
n, group_size, sm_count, sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp, True)
opcheck(
torch.ops._C.allspark_w8a16_gemm,
(
input,
qw_reorder,
s_reorder,
zp_reorder,
n,
group_size,
sm_count,
sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp,
True,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.allspark_w8a16_gemm(
input,
qw_reorder,
s_reorder,
zp_reorder,
n,
group_size,
sm_count,
sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp,
True,
)
output_ref = torch.matmul(input, w_ref)
torch.cuda.synchronize()

View File

@@ -8,40 +8,42 @@ from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.",
)
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_TRITON_AWQ", "0")
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))
opcheck(
torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy),
)
@pytest.mark.skip(reason="Not working; needs investigation.")
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.",
)
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_TRITON_AWQ", "0")
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
input = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.randint(
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
)
qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))

View File

@@ -4,11 +4,15 @@
Run `pytest tests/kernels/quantization/test_awq_triton.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
awq_dequantize_triton,
awq_gemm_triton,
)
from vllm.platforms import current_platform
device = "cuda"
@@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor):
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
group_size: int) -> torch.Tensor:
def awq_dequantize_torch(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
if group_size == -1:
group_size = qweight.shape[0]
bits = 4
shifts = torch.arange(0, 32, bits, device=qzeros.device)
iweights = torch.bitwise_right_shift(qweight[:, :, None],
shifts[None, None, :]).to(torch.int8)
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8
)
iweights = iweights.view(iweights.shape[0], -1)
zeros = torch.bitwise_right_shift(qzeros[:, :, None],
shifts[None, None, :]).to(torch.int8)
zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
torch.int8
)
zeros = zeros.view(qzeros.shape[0], -1)
zeros = reverse_awq_order(zeros)
@@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):
if group_size == -1:
group_size = qweight_rows
@@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
current_platform.seed_everything(0)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device)
scales = torch.rand(scales_rows,
scales_cols,
dtype=scales_dtype,
device=device)
zeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device)
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device,
)
scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
zeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device,
)
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
assert (not torch.any(torch.isinf(iweights_triton))
and not torch.any(torch.isnan(iweights_triton)))
assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
torch.isnan(iweights_triton)
)
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
@@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):
if group_size == -1:
group_size = K
@@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size):
current_platform.seed_everything(0)
input = torch.rand((input_rows, input_cols),
dtype=input_dtype,
device=device)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
device=device)
qzeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(qzeros_rows, qzeros_cols),
device=device)
scales = torch.rand((scales_rows, scales_cols),
dtype=scales_dtype,
device=device)
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
qweight = torch.randint(
0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
)
qzeros = torch.randint(
0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
)
scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)
output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
split_k_iters)
output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
assert (not torch.any(torch.isinf(output_triton))
and not torch.any(torch.isnan(output_triton)))
assert not torch.any(torch.isinf(output_triton)) and not torch.any(
torch.isnan(output_triton)
)
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
output_torch = torch.matmul(input, dequantized_weights)
assert (not torch.any(torch.isinf(output_torch))
and not torch.any(torch.isnan(output_torch)))
assert not torch.any(torch.isinf(output_torch)) and not torch.any(
torch.isnan(output_torch)
)
torch.testing.assert_close(output_triton.cpu(),
output_torch.cpu(),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(
output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
)

View File

@@ -7,20 +7,26 @@ import itertools
import pytest
import torch
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
cutlass_scaled_mm,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8)
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8,
)
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -51,7 +57,8 @@ def setup_cuda():
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
@@ -60,15 +67,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)
assert torch.allclose(out.to(torch.float32),
ref_out.to(torch.float32),
rtol=0.15)
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
assert torch.allclose(scale, ref_scale)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
@@ -89,14 +95,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@@ -127,32 +131,32 @@ def test_w8a8_block_fp8_cutlass_matmul():
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
# Hopper requires row-major format for scales
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(
90) else Bs
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
A_fp8, As = per_token_group_quant_fp8(A_fp32,
block_size[1],
column_major_scales=False)
A_fp8, As = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=False
)
# CUTLASS uses column-major format for scales
A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=True)
A_fp32, block_size[1], column_major_scales=True
)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass,
block_size, out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = cutlass_scaled_mm(
A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not has_deep_gemm(),
reason="DeepGemm kernels not available.")
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
@@ -172,20 +176,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
assert As_fp8.shape == (M, (K + 127) // 128), (
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
)
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001

View File

@@ -10,12 +10,12 @@ import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul)
w8a8_block_int8_matmul,
)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -36,8 +36,10 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
@@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001

View File

@@ -11,12 +11,11 @@ import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
sparse_cutlass_supported,
)
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
@@ -40,9 +39,7 @@ def prune_to_2_4(tensor):
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
@@ -55,32 +52,31 @@ def prune_to_2_4(tensor):
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor):
def check_compress_decompress_invariance(
dtype: torch.dtype,
b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor,
):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
b_compressed,
b_metadata,
eye_scale,
eye_scale,
out_dtype=out_dtype)
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
)
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()
a = torch.randn((m, k), device="cuda")
b = torch.randn((n, k), device="cuda").t()
if dtype == torch.int8:
# ensure A and B aren't all zeros after rounding
@@ -107,32 +103,25 @@ def make_rand_sparse_tensors(
return b_compressed, e, a, b
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512
# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
@@ -161,105 +150,87 @@ MNK_FACTORS = [
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool):
def test_cutlass_sparse_gemm(
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
def test_cutlass_sparse_int8_gemm(
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
"""
import random
import pytest
@@ -36,9 +37,7 @@ MNK_FACTORS = [
(512, 24576, 128),
]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
@@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape):
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
group_shape = group_scale_helper(shape, group_shape)
return tuple(
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def cutlass_fp8_gemm_helper(m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
def cutlass_fp8_gemm_helper(
m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda",
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a = to_fp8(torch.randn((m, k), device=device))
@@ -80,8 +80,8 @@ def cutlass_fp8_gemm_helper(m: int,
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a = scale_a.t().contiguous().t()
@@ -89,7 +89,7 @@ def cutlass_fp8_gemm_helper(m: int,
scale_b = scale_b.t().contiguous().t()
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
@@ -98,18 +98,19 @@ def cutlass_fp8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
def cutlass_int8_gemm_helper(m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
def cutlass_int8_gemm_helper(
m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda",
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a = to_int8(torch.randn((m, k), device=device) * 5)
@@ -118,11 +119,11 @@ def cutlass_int8_gemm_helper(m: int,
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
@@ -131,145 +132,192 @@ def cutlass_int8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize(
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
def test_cutlass_int8_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
cutlass_int8_gemm_helper(
m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
def test_cutlass_int8_gemm_output_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_int8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_output_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize(
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
b_scale_group_shape, use_bias, torch.bfloat16,
device)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_devices(
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
torch.bfloat16,
device,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device)
def test_cutlass_int8_gemm_devices(
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
cutlass_int8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device,
)
# For the following two tests:
@@ -277,32 +325,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
cutlass_fp8_gemm_helper(
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
def test_cutlass_int8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
cutlass_int8_gemm_helper(
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("m", [32, 64, 128])
@@ -310,8 +368,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
out_dtype: torch.dtype):
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
@@ -328,7 +385,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
b_dq = scale_b * bq_f32
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
@@ -340,18 +397,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
assert azp_bias.shape == (1, n)
assert azp_bias[0, :].shape == (n, )
assert azp_bias[0, :].shape == (n,)
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
dtype=out_dtype, device='cuda')
baseline_q = (
scale_a.to(device="cpu")
* scale_b.to(device="cpu")
* ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
).to(dtype=out_dtype, device="cuda")
out = ops.cutlass_scaled_mm(aq_i8,
bq_i8,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=azp_bias[0, :])
out = ops.cutlass_scaled_mm(
aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
)
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
@@ -362,8 +418,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
use_bias: bool, azp_per_token: bool):
def test_cutlass_int8_azp(
m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
):
m_azp = m if azp_per_token else 1
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
@@ -377,16 +434,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq = scale_b * bq_f32
azp_a = torch.rand(
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
torch.testing.assert_close(a_dq,
scale_a * aq_f32 - azp_a,
rtol=1e-4,
atol=1e-3)
torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -396,8 +449,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
# int32 mm not supported on CUDA
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
# Hadamard is just the sum of the cols
@@ -406,14 +459,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
func_bias = bias if use_bias else None
if azp_per_token:
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_adj_i32, azp_i32,
func_bias)
out = ops.cutlass_scaled_mm_azp(
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
)
else:
azp_with_adj_i32 = azp_i32 * azp_adj_i32
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_with_adj_i32, None,
func_bias)
out = ops.cutlass_scaled_mm_azp(
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
@@ -423,13 +476,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
if azp_per_token:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
func_bias))
opcheck(
torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
)
else:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
func_bias))
opcheck(
torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
)
# Test working with a subset of A and B
@@ -445,23 +500,14 @@ def test_cutlass_subset():
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
def __init__(self, b, scale_a, scale_b, out_dtype):
super().__init__()
self.b = b
@@ -470,8 +516,9 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype = out_dtype
def forward(self, a):
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
return ops.cutlass_scaled_mm(
a, self.b, self.scale_a, self.scale_b, self.out_dtype
)
@pytest.mark.parametrize("per_act_token", [True, False])
@@ -485,10 +532,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
m_a_scales = m if per_act_token else 1
n_b_scales = n if per_out_ch else 1
scale_a = (torch.randn(
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
@@ -502,13 +547,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
out.zero_()
g.replay()
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
baseline = torch.mm(
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
).to(torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
def test_cutlass_support_opcheck():
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
@pytest.mark.parametrize("num_experts", [8, 64])
@@ -517,11 +563,13 @@ def test_cutlass_support_opcheck():
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_fp8_group_gemm(
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
# Device and dtype setup
device = "cuda"
out_dtype = torch.half
@@ -533,13 +581,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
b_scales_tensors = []
baseline_tensors = []
expert_offsets = torch.zeros((num_experts + 1),
device=device,
dtype=torch.int64)
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
problem_sizes = torch.zeros((num_experts, 3),
device=device,
dtype=torch.int32)
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
if not per_act_token:
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
@@ -566,75 +610,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
b_tensors.append(b_g)
# Set up A/B scales
scale_b = torch.randn((1, n_b_scales),
device=device,
dtype=torch.float32)
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
b_scales_tensors.append(scale_b)
if per_act_token:
scale_a = torch.randn((m_a_scales, 1),
device=device,
dtype=torch.float32)
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
a_scales_tensors.append(scale_a)
else:
scale_a = one_scale_a
# Compute baseline result for this group
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
None)
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
baseline_tensors.append(baseline_g)
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
device=device,
dtype=torch.float8_e4m3fn)
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
device=device,
dtype=torch.float8_e4m3fn)
a_tensors_stacked = torch.empty(
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
)
b_tensors_stacked = torch.empty(
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
for g in range(num_experts):
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
1]] = a_tensors[g]
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
b_tensors_stacked[g] = b_tensors[g].t()
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
if per_act_token:
a_scales_tensors_stacked = torch.empty(
(expert_offsets[num_experts], 1),
device=device,
dtype=torch.float32)
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
)
for g in range(num_experts):
a_scales_tensors_stacked[
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
a_scales_tensors[g]
)
else:
a_scales_tensors_stacked = one_scale_a
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
device=device,
dtype=torch.float32)
b_scales_tensors_stacked = torch.empty(
(num_experts, n_b_scales), device=device, dtype=torch.float32
)
for g in range(num_experts):
b_scales_tensors_stacked[g] = b_scales_tensors[g]
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
device=device,
dtype=out_dtype)
out_tensors_stacked = torch.zeros(
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
)
ab_strides = torch.full((num_experts, ),
a_tensors_stacked.stride(0),
device="cuda",
dtype=torch.int64)
c_strides = torch.full((num_experts, ),
out_tensors_stacked.stride(0),
device="cuda",
dtype=torch.int64)
ab_strides = torch.full(
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
)
c_strides = torch.full(
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
)
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked, a_scales_tensors_stacked,
b_scales_tensors_stacked, expert_offsets[:-1],
problem_sizes, ab_strides, ab_strides, c_strides,
per_act_token, per_out_ch)
ops.cutlass_moe_mm(
out_tensors_stacked,
a_tensors_stacked,
b_tensors_stacked,
a_scales_tensors_stacked,
b_scales_tensors_stacked,
expert_offsets[:-1],
problem_sizes,
ab_strides,
ab_strides,
c_strides,
per_act_token,
per_out_ch,
)
# Validate each group's result against the baseline
for g in range(num_experts):
baseline = baseline_tensors[g]
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)

View File

@@ -13,7 +13,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -24,16 +26,33 @@ from vllm.scalar_type import ScalarType, scalar_types
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
(13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
(64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
(1024, 4096, 8192), (1024, 8192, 4096)]
MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(1, 8192, 28672),
(13, 8192, 4096),
(26, 4096, 8192),
(64, 4096, 4096),
(64, 8192, 28672),
(257, 128, 4096),
(257, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]
# TODO(czhu): get supported schedules from fn
SCHEDULES = [
'128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
'128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
'128x256_1x1x1', '128x256_2x1x1'
"128x16_1x1x1",
"256x16_1x1x1",
"128x32_1x1x1",
"256x32_1x1x1",
"128x64_1x1x1",
"256x64_1x1x1",
"128x128_1x1x1",
"256x128_1x1x1",
"128x256_1x1x1",
"128x256_2x1x1",
]
@@ -60,19 +79,23 @@ class Tensors:
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
*(
TypeConfig(act_type=torch.float8_e4m3fn,
weight_type=w_type,
output_type=o_type,
group_scale_type=torch.float8_e4m3fn,
channel_scale_type=torch.float32,
token_scale_type=torch.float32)
TypeConfig(
act_type=torch.float8_e4m3fn,
weight_type=w_type,
output_type=o_type,
group_scale_type=torch.float8_e4m3fn,
channel_scale_type=torch.float32,
token_scale_type=torch.float32,
)
for w_type in [scalar_types.int4]
# TODO(czhu): fp16 out type
for o_type in [torch.bfloat16]),
for o_type in [torch.bfloat16]
),
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
@@ -86,26 +109,28 @@ IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min,
max=finfo.max).to(dtype=torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
def cutlass_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
def cutlass_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(w,
wtype,
group_size=group_size,
zero_points=zero_points)
w_ref, w_q, w_s, w_zp = quantize_weights(
w, wtype, group_size=group_size, zero_points=zero_points
)
# since scales are cast to fp8, we need to compute w_ref this way
w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to(
torch.float32).repeat_interleave(group_size, dim=0)).to(atype)
w_ref = (
(w_q).to(torch.float32)
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
).to(atype)
# bit mask prevents sign extending int4 when packing
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
@@ -117,12 +142,14 @@ def cutlass_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_packed, w_s_packed, w_zp
def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> Tensors:
def create_test_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> Tensors:
m, n, k = shape
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
print(
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
)
a = to_fp8(torch.randn((m, k), device="cuda"))
w = to_fp8(torch.randn((k, n), device="cuda"))
@@ -133,30 +160,34 @@ def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
False)
a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
)
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)
# for the practical use case we need per-tok scales for fp8 activations
w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
# weights are already per-group quantized, use placeholder here
w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)
w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
return Tensors(
w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)
def mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
def mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
):
# CUTLASS upstream uses fp8 with fastaccum as reference
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
output_ref = torch._scaled_mm(
@@ -165,7 +196,8 @@ def mm_test_helper(types: TypeConfig,
tensors.w_tok_s.unsqueeze(1),
tensors.w_ch_s.unsqueeze(0),
out_dtype=types.output_type,
use_fast_accum=True)
use_fast_accum=True,
)
output = ops.cutlass_w4a8_mm(
a=tensors.a,
@@ -179,17 +211,15 @@ def mm_test_helper(types: TypeConfig,
print(output)
print(output_ref)
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=1e-3,
atol=1e-3)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
@@ -201,7 +231,6 @@ def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
@@ -210,8 +239,9 @@ class W4A8Layer(torch.nn.Module):
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
def test_w4a8_cuda_graph():
m, n, k = 512, 4096, 4096
@@ -224,10 +254,11 @@ def test_w4a8_cuda_graph():
zero_points = False
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
)
w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)
# Construct a trivial model with a single layer that calls the kernel
model = W4A8Layer(
@@ -244,7 +275,8 @@ def test_w4a8_cuda_graph():
w_tok_s.unsqueeze(1),
w_ch_s.unsqueeze(0),
out_dtype=torch.bfloat16,
use_fast_accum=True)
use_fast_accum=True,
)
# Run the model with a cuda graph
stream = torch.cuda.Stream()

View File

@@ -2,8 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
from nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
convert_swizzled_to_linear,
dequantize_nvfp4_to_dtype,
)
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@@ -41,18 +45,12 @@ def get_ref_results(
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -72,8 +70,7 @@ def test_flashinfer_nvfp4_gemm(
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
pytest.skip(
"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
current_platform.seed_everything(seed)
m, n, packed_k = shape
@@ -82,10 +79,12 @@ def test_flashinfer_nvfp4_gemm(
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
@@ -113,14 +112,18 @@ def test_flashinfer_nvfp4_gemm(
if backend == "trtllm":
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
epilogue_tile_m)
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size)
b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
b_scale_interleaved.shape).view(torch.float8_e4m3fn))
b_scale_interleaved, n, k, block_size
)
b_scale_interleaved = (
flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m
)
.reshape(b_scale_interleaved.shape)
.view(torch.float8_e4m3fn)
)
with flashinfer.autotune(autotune):
out = flashinfer_scaled_fp4_mm(
@@ -133,7 +136,4 @@ def test_flashinfer_nvfp4_gemm(
backend=backend,
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

View File

@@ -9,8 +9,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
if not current_platform.has_device_capability(100):
pytest.skip(
reason=
"Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
allow_module_level=True,
)
@@ -53,7 +52,7 @@ def test_flashinfer_fp8_gemm(
).to(dtype=dtype)
if use_bias:
bias = torch.randn((n, ), dtype=dtype, device=device)
bias = torch.randn((n,), dtype=dtype, device=device)
expected_out = expected_out + bias
else:
bias = None

View File

@@ -5,9 +5,11 @@ import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from tests.kernels.quant_utils import (
FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant,
)
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
@@ -18,23 +20,25 @@ SCALE_UBS = [True, False]
SEEDS = [0]
def opcheck_fp8_quant(output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False):
def opcheck_fp8_quant(
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
elif use_per_token_if_dynamic:
scale = torch.empty((input.shape[0], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub))
scale = torch.empty(
(input.shape[0], 1), device=input.device, dtype=torch.float32
)
opcheck(
torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub),
)
else:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
scale = torch.empty(
(input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32,
)
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
@@ -44,30 +48,29 @@ def opcheck_fp8_quant(output,
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
def test_dynamic_per_token_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans
x = (
torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6
) # avoid nans
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
if scale_ub else None
scale_ub = (
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
)
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
ops_out, ops_scales = ops.scaled_fp8_quant(
x, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
torch.testing.assert_close(ref_scales, ops_scales)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
opcheck_fp8_quant(ops_out,
x,
None,
scale_ub,
use_per_token_if_dynamic=True)
opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,8 +78,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_per_tensor_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -85,8 +89,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ops_out, ops_scale = ops.scaled_fp8_quant(x)
torch.testing.assert_close(ref_scale, ops_scale)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
opcheck_fp8_quant(ops_out, x)

View File

@@ -6,8 +6,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@@ -18,13 +17,14 @@ from vllm.platforms import current_platform
(64, 1024, 64), # Medium
(128, 2048, 128), # Large
(8, 513, 64), # Non-divisible (native only)
])
],
)
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int,
use_ue8m0: bool) -> None:
def test_quantfp8_group_functionality(
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
) -> None:
"""Test QuantFP8 group quantization with various configurations.
Tests both CUDA and native implementations, column-major scales,
@@ -32,16 +32,17 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
"""
current_platform.seed_everything(seed)
x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
expected_num_groups = (hidden_dim + group_size - 1) // group_size
is_divisible = hidden_dim % group_size == 0
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
quant_op = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0,
)
# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -49,10 +50,12 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
assert scales_native.shape == (batch_size, expected_num_groups)
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
quant_op_col = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0,
)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
@@ -86,41 +89,48 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
# Test with 3D input
batch1, batch2, hidden_dim = 4, 8, 1024
x_3d = torch.randn(
(batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
x_3d = (
torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
* 8
)
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
quant_op = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0,
)
x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
assert scales.shape == (batch1, batch2, hidden_dim // group_size)
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
quant_op_col = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0,
)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
# Test with 4D input
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
dtype=torch.bfloat16,
device="cuda") * 8
x_4d = (
torch.randn(
(batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
)
* 8
)
x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
assert x_quant_4d.shape == x_4d.shape
assert scales_4d.shape == (batch1, batch2, batch3,
hidden_dim // group_size)
assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)
_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
batch3)
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)
@pytest.mark.parametrize("seed", [42])
@@ -132,30 +142,24 @@ def test_quantfp8_group_edge_cases(seed: int) -> None:
group_size = 64
# Test with single group (group_size >= hidden_dim)
x_small = torch.randn(
(batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
quant_op = QuantFP8(
static=False, group_shape=group_shape, column_major_scales=False
)
x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
assert x_quant_small.shape == x_small.shape
assert scales_small.shape == (batch_size, 1)
# Test with zero inputs
x_zero = torch.zeros((batch_size, 256),
dtype=torch.bfloat16,
device="cuda")
x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
assert x_quant_zero.shape == x_zero.shape
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"
# Test very large values
x_large = torch.full((batch_size, 256),
1000.0,
dtype=torch.bfloat16,
device="cuda")
x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
assert x_quant_large.shape == x_large.shape
# FP8 max is typically 448 or 224, so scales should be > 1

View File

@@ -13,33 +13,42 @@ from vllm import _custom_ops as ops # noqa: F401
def test_ggml_opcheck(quant_type):
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
shape = [256, 1152]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize,
(qweight, quant_type, m, n, torch.float16))
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16))
x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8,
(qweight, x, quant_type, qweight.shape[0]))
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
(qweight, x, quant_type, qweight.shape[0]))
x = torch.rand((m, 512), device="cuda", dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0]))
opcheck(
torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])
)
shape = [256, 1024, 336]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
sorted_token_ids = torch.arange(776, device='cuda')
expert_ids = torch.randint(0, 256, (194, ), device='cuda')
num_tokens_post_padded = torch.tensor([1],
dtype=torch.int64,
device='cuda')
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
x = torch.rand((1, 1024), device="cuda", dtype=torch.float16)
sorted_token_ids = torch.arange(776, device="cuda")
expert_ids = torch.randint(0, 256, (194,), device="cuda")
num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda")
opcheck(torch.ops._C.ggml_moe_a8,
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
quant_type, qweight.shape[0], 1, x.shape[0]))
opcheck(
torch.ops._C.ggml_moe_a8,
(
x,
qweight,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
quant_type,
qweight.shape[0],
1,
x.shape[0],
),
)
topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32)
topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32)
opcheck(
torch.ops._C.ggml_moe_a8_vec,
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]))
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]),
)

View File

@@ -18,8 +18,8 @@ GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
@@ -27,8 +27,8 @@ def get_gguf_sample_tensors(
def get_gguf_MoE_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
@@ -68,17 +68,20 @@ QUANT_TYPES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_dequantize(
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))
ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape), dtype)
ref_output = torch.tensor(
dequantize(tensor.data, quant_type), device="cuda"
).to(dtype)
output = ops.ggml_dequantize(
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
@@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
qweight.shape[0]).to(dtype)
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(
dtype
)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
@@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
],
)
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_mmq(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
):
current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
@@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(output,
ref_output,
atol=atols[dtype],
rtol=rtols[dtype])
torch.testing.assert_close(
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType, top_k: int):
def test_moe(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
top_k: int,
):
current_platform.seed_everything(0)
H, E = 1024, 256
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
topk_ids = torch.randint(0,
E, (num_tokens, top_k),
device="cuda",
dtype=torch.int32)
topk_ids = torch.randint(
0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32
)
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
w13 = tensors[0]
w2 = tensors[1]
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
device="cuda").to(dtype)
w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to(
dtype
)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype)
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, "silu")
output = _fused_moe_gguf(
x,
torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data, device="cuda"),
topk_weights,
topk_ids,
quant_type,
quant_type,
"silu",
)
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
ref_output = fused_experts(
x, w13_dequant, w2_dequant, topk_weights, topk_ids
).reshape(output.shape)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)

View File

@@ -8,25 +8,22 @@ from vllm import _custom_ops as ops # noqa: F401
def test_gptq_shuffle_opcheck():
weight = torch.randint(-2000000,
2000000, (1792, 4096),
device='cuda',
dtype=torch.int32)
perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
weight = torch.randint(
-2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32
)
perm = torch.empty((0,), device="cuda", dtype=torch.int32)
bit = 4
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
def test_gptq_gemm_opcheck():
a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
weight = torch.randint(-2000000,
2000000, (512, 6144),
device='cuda',
dtype=torch.int32)
zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
a = torch.rand((240, 4096), device="cuda", dtype=torch.float16)
weight = torch.randint(
-2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32
)
zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32)
scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16)
idx = torch.empty((0,), device="cuda", dtype=torch.int32)
use_exllama = True
bit = 4
opcheck(torch.ops._C.gptq_gemm,
(a, weight, zeros, scales, idx, use_exllama, bit))
opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit))

View File

@@ -15,7 +15,8 @@ from vllm import _custom_ops as ops
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
x = torch.eye(hidden_dim, dtype=dtype, device=device)
hadamard = deterministic_hadamard_matrix(
hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim)
hidden_dim, dtype=torch.float64, device="cuda"
) / math.sqrt(hidden_dim)
y = ops.hadacore_transform(x.clone())
y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)

View File

@@ -11,12 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8)
per_token_quant_int8,
)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
@@ -26,14 +26,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(
), "B must be a 2D contiguous tensor"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K, )
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
@@ -43,8 +42,7 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
"""This function performs fused moe with per-column int8 quantization
using native torch."""
@@ -66,25 +64,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
output_dtype=a.dtype)
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
output_dtype=a.dtype)
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -102,8 +97,10 @@ TOP_KS = [2, 6]
SEEDS = [0]
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
torch.manual_seed(seed)
@@ -130,8 +127,9 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
ref_out = torch_w8a8_per_column_moe(
a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
)
quant_config = FusedMoEQuantConfig.make(
torch.int8,
@@ -151,7 +149,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.05

View File

@@ -18,26 +18,24 @@ SCALE = [0.1, 2.1]
def opcheck_int8_quant_static(output, input, scale, azp=None):
if azp is None:
opcheck(torch.ops._C.static_scaled_int8_quant,
(output, input, scale, None))
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None))
else:
opcheck(torch.ops._C.static_scaled_int8_quant,
(output, input, scale, azp))
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
scale = torch.empty(
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
)
if symmetric:
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
(output, input, scale, None))
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None))
else:
azp = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.int32)
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
(output, input, scale, azp))
azp = torch.empty(
(input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.int32,
)
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_scaled_int8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -68,30 +67,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_scaled_int8_azp_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
# calculate scale and azp, and adjust the range
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
torch.int32)
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32)
torch_out = ((x / scales).round() + azps).clamp(
int8_traits.min, int8_traits.max).to(torch.int8)
assert torch_out.min() >= int8_traits.min and torch_out.max(
) <= int8_traits.max
torch_out = (
((x / scales).round() + azps)
.clamp(int8_traits.min, int8_traits.max)
.to(torch.int8)
)
assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
if (not torch.allclose(scales_out, scales)):
if not torch.allclose(scales_out, scales):
print(torch.argmax(torch.abs(scales_out - scales)))
torch.testing.assert_close(scales_out, scales)
# big atol to account for rounding errors
@@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
def test_static_scaled_int8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
out1 = (x / scale_arg).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
(x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
)
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
assert scale2 is scale_arg
@@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("scale", SCALE)
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
def test_static_scaled_int8_azp_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
scale: float,
azp: int,
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
((x / scale).round() + azp)
.clamp(int8_traits.min, int8_traits.max)
.to(torch.int8)
)
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
out2, scale2, azp2 = scaled_int8_quant(x,
scale_arg,
azp_arg,
symmetric=False)
out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False)
assert scale2 is scale_arg
assert azp2 is azp_arg
@@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
int32_traits = torch.iinfo(torch.int32)
val = float(int32_traits.max if is_max else int32_traits.min)
x_vals = [[
nextafter(val, inf), val + 1, val, val - 1,
nextafter(val, -inf)
]]
x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]]
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)

View File

@@ -15,15 +15,16 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
query_machete_supported_group_sizes)
query_machete_supported_group_sizes,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
@@ -72,29 +73,38 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
# GPTQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]
),
# AWQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]
),
# # QQQ style
# *(TypeConfig(act_type=torch.int8,
# weight_type=scalar_types.uint4b8,
@@ -133,17 +143,18 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
@@ -152,7 +163,8 @@ def machete_quantize_and_pack(atype: torch.dtype,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
ref_zero_points_after_scales=True,
)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # convert to col major
@@ -163,15 +175,18 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp
def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
print(
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
)
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
@@ -186,8 +201,13 @@ def create_test_tensors(shape: tuple[int, int, int],
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
types.group_zero_type is not None)
a.dtype,
w,
types.weight_type,
types.group_scale_type,
group_size,
types.group_zero_type is not None,
)
if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
@@ -196,35 +216,47 @@ def create_test_tensors(shape: tuple[int, int, int],
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)
w_ch_s = None if types.channel_scale_type is None else\
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)
w_ch_s = (
None
if types.channel_scale_type is None
else rand_data((n,), types.channel_scale_type)
)
w_tok_s = (
None
if types.token_scale_type is None
else rand_data((m,), types.token_scale_type)
)
return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
return Tensors(
w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)
# None stype means scales use the same dtype as a
def machete_mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
if tensors.w_ch_s is not None:
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0)
).to(output_ref_type)
if tensors.w_tok_s is not None:
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1)
).to(output_ref_type)
output = ops.machete_mm(
a=tensors.a,
@@ -245,23 +277,23 @@ def machete_mm_test_helper(types: TypeConfig,
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if tensors.w_g_zp is not None\
atol = (
1
if tensors.w_g_zp is not None
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
)
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=rtol,
atol=atol)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=rtol, atol=atol
)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
@@ -275,20 +307,20 @@ def test_machete_all_schedules(shape, types: TypeConfig):
tensors = create_test_tensors(shape, types, group_size)
print(f"MNK = {shape}")
for schedule in ops.machete_supported_schedules(
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type):
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type,
):
print(f"Testing schedule {schedule}")
machete_mm_test_helper(types, tensors, group_size, schedule)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
@@ -306,19 +338,22 @@ def test_machete_heuristic(shape, types: TypeConfig):
# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
group_size = 128
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
@@ -331,29 +366,30 @@ def test_machete_devices(device: str):
# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_subset():
group_size = 128
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
tensors = create_test_tensors((512, 4096, 4096),
type_config,
group_size,
subset_stride_factor=2)
tensors = create_test_tensors(
(512, 4096, 4096), type_config, group_size, subset_stride_factor=2
)
machete_mm_test_helper(type_config, tensors, group_size)
# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
@@ -362,8 +398,9 @@ class MacheteLayer(torch.nn.Module):
return ops.machete_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_cuda_graph():
m, n, k = 512, 4096, 4096
@@ -375,7 +412,8 @@ def test_machete_cuda_graph():
zero_points = False
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, b, wtype, stype, group_size, zero_points)
a.dtype, b, wtype, stype, group_size, zero_points
)
# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""
import pytest
import torch
@@ -11,24 +12,44 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types)
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
MarlinWorkspace,
awq_marlin_quantize,
get_weight_perm,
marlin_quantize,
marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
awq_pack,
gptq_pack,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
@@ -56,24 +77,27 @@ DTYPES = [torch.float16, torch.bfloat16]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
torch.abs(output_ref)
)
def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
@@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)
# Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
@@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)
opcheck(torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
@@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
b_weight = rand_data((size_k, size_n))
# Quantize
w_ref, q_w, s, zp = quantize_weights(b_weight,
quant_type,
group_size,
zero_points=True)
w_ref, q_w, s, zp = quantize_weights(
b_weight, quant_type, group_size, zero_points=True
)
# Pack to AWQ format
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)
opcheck(torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
@@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors, act_order, is_k_full, use_atomic_add,
use_fp32_reduce, dtype):
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
@@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size
)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size
)
marlin_s2 = None
g_idx = None
@@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
@@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
b_weight, quant_type, group_size
)
g_idx = None
sort_indices = None
marlin_s2 = None
@@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
use_atomic_add,
use_fp32_reduce,
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.gptq_marlin_gemm(
a_input,
@@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
def marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
):
return ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
@@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
b_weight, quant_type, group_size
)
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
workspace_24 = MarlinWorkspace(
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
)
output_ref = torch.matmul(a_input, w_24_ref)
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_24_gemm,
(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = marlin_24_gemm_tester(
a_input,
@@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@@ -386,22 +478,22 @@ def test_hqq_marlin_gemm(
a_input = rand_data((size_m, size_k))
dev = a_input.device
b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
dev
)
marlin_s = marlin_permute_scales(
scale.transpose(1, 0), size_k, size_n, group_size
).to(dev)
marlin_zp = marlin_permute_scales(
zero.transpose(1, 0), size_k, size_n, group_size
).to(dev)
g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
@@ -433,8 +525,7 @@ def test_hqq_marlin_gemm(
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat
output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
torch.cuda.synchronize()
@@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input():
big_m = size_m * 2
big_k = size_k * 2
a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
@@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m):
size_k, size_n = 1024, 2048
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
b_bias = rand_data((size_n, )) * 10
b_bias = rand_data((size_n,)) * 10
marlin_bias = marlin_permute_bias(b_bias)
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)

View File

@@ -8,15 +8,27 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
(90, 128), (150, 128), (150, 48), (90, 80)]
PAD_SHAPES = [
(90, 64),
(150, 64),
(128, 48),
(128, 80),
(150, 80),
(90, 48),
(90, 128),
(150, 128),
(150, 48),
(90, 80),
]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -31,7 +43,22 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
BLOCK_SIZE = 16
@@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale):
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
@@ -131,7 +157,7 @@ def test_quantize_to_fp4(
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
dtype = torch.float16
current_platform.seed_everything(42)
torch.set_default_device('cuda:0')
torch.set_default_device("cuda:0")
m, n = pad_shape

View File

@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
@@ -19,26 +20,31 @@ PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device):
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert (m_k == n_k)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -60,25 +66,34 @@ def test_nvfp4_gemm(
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,
device)
out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, alpha, dtype)
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
)
out = ops.cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

View File

@@ -13,15 +13,15 @@ from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
scale_ue8m0: bool, group_size: int):
def test_per_token_group_quant_fp8(
shape, column_major: bool, scale_ue8m0: bool, group_size: int
):
device = "cuda"
torch.manual_seed(42)
num_tokens, hidden_dim = shape
x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
# cuda path
out_q, scale = fp8_utils.per_token_group_quant_fp8(
@@ -53,8 +53,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
torch.manual_seed(42)
num_tokens, hidden_dim = shape
x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
# cuda path
out_q, scale = int8_utils.per_token_group_quant_int8(

View File

@@ -63,12 +63,11 @@ SEEDS = [0]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
#TODO: Zero-centering the inputs causes errors for LLMM1!
# TODO: Zero-centering the inputs causes errors for LLMM1!
# Without that the numbers quickly saturate, and may
# be giving false matches.
A = torch.rand(n, k, dtype=dtype, device="cuda")
@@ -83,14 +82,13 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
@@ -101,16 +99,15 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -121,16 +118,15 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -143,7 +139,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
@@ -153,13 +150,10 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count())
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
assert torch.allclose(out, ref_out, rtol=0.01)
@@ -169,25 +163,24 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - .5) * xavier
B = (torch.rand(m, k, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count(), BIAS)
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
)
assert torch.allclose(out, ref_out, rtol=0.01)

View File

@@ -3,16 +3,20 @@
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
FP4_DTYPE = torch.uint8
FP8_DTYPE = current_platform.fp8_dtype()
@@ -30,24 +34,24 @@ def test_silu_mul_nvfp4_quant(
shape: tuple[int, int],
) -> None:
current_platform.seed_everything(42)
device = 'cuda:0'
device = "cuda:0"
torch.set_default_device(device)
x = torch.randn(shape, dtype=dtype)
# ref op
ref_output = SiluAndMul().forward_native(x)
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(ref_output).max().to(torch.float32))
ref_output_quant, ref_block_scale = scaled_fp4_quant(
ref_output, ref_global_scale)
ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
ref_output
).max().to(torch.float32)
ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale)
# fused op
fused_output_quant = torch.empty_like(ref_output_quant)
fused_block_scale = torch.empty_like(ref_block_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
fused_block_scale, x,
ref_global_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(
fused_output_quant, fused_block_scale, x, ref_global_scale
)
# check dtype
assert ref_output_quant.dtype == FP4_DTYPE
@@ -59,17 +63,14 @@ def test_silu_mul_nvfp4_quant(
assert ref_block_scale.shape == fused_block_scale.shape
# check dequantized output
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
ref_block_scale,
ref_global_scale, dtype,
device)
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
fused_block_scale,
ref_global_scale, dtype,
device)
ref_output_dequant = dequantize_nvfp4_to_dtype(
ref_output_quant, ref_block_scale, ref_global_scale, dtype, device
)
fused_output_dequant = dequantize_nvfp4_to_dtype(
fused_output_quant, fused_block_scale, ref_global_scale, dtype, device
)
atol, rtol = 3e-1, 3e-1
torch.testing.assert_close(ref_output_dequant,
fused_output_dequant,
atol=atol,
rtol=rtol)
torch.testing.assert_close(
ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol
)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
"""
import importlib
from typing import Optional
@@ -15,17 +16,19 @@ from vllm.platforms import current_platform
device = "cuda"
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
"vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm"
)
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
def torch_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def torch_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
out = scale_b.T * out
@@ -44,20 +47,22 @@ def get_8bit_types():
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(
vllm_runner, example_prompts, model_path, max_tokens, num_logprobs
):
dtype = "bfloat16"
with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs)
MNK_FACTORS = [
@@ -76,10 +81,10 @@ MNK_FACTORS = [
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
use_scalar_scale_b, use_bias):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
).is_floating_point()
def test_scaled_mm(
M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()
current_platform.seed_everything(0)
@@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
#
# So, the values here are kept small enough to avoid this situation.
if is_floating_point_type(in_dtype):
a = (0.25 * torch.rand(
(M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand(
(K, N), dtype=torch.float32, device=device)).to(in_dtype)
a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
else:
a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
@@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
bias = None
if use_bias:
bias = torch.rand((N, ), device=device, dtype=out_dtype)
bias = torch.rand((N,), device=device, dtype=out_dtype)
c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

View File

@@ -4,8 +4,10 @@ import pytest
import torch
from tests.kernels.utils import opcheck
from vllm._custom_ops import (apply_repetition_penalties_cuda,
apply_repetition_penalties_torch)
from vllm._custom_ops import (
apply_repetition_penalties_cuda,
apply_repetition_penalties_torch,
)
from vllm.platforms import current_platform
NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025]
@@ -21,8 +23,9 @@ DTYPES = [torch.float32, torch.float16]
@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test for checking CUDA kernel")
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test for checking CUDA kernel"
)
@torch.inference_mode()
def test_apply_repetition_penalties(
num_seqs: int,
@@ -32,7 +35,7 @@ def test_apply_repetition_penalties(
seed: int,
) -> None:
"""
Test the apply_repetition_penalties custom op
Test the apply_repetition_penalties custom op
against a reference implementation.
"""
current_platform.seed_everything(seed)
@@ -46,39 +49,40 @@ def test_apply_repetition_penalties(
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
# Mark some tokens as repeated in prompt and output
prompt_indices = torch.randint(0, vocab_size,
(num_seqs, max(1, vocab_size // 200)))
output_indices = torch.randint(0, vocab_size,
(num_seqs, max(1, vocab_size // 200)))
prompt_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200)))
output_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200)))
for i in range(num_seqs):
prompt_mask[i, prompt_indices[i]] = True
output_mask[i, output_indices[i]] = True
# Create repetition penalties tensor
repetition_penalties = torch.full((num_seqs, ),
repetition_penalty,
dtype=dtype)
repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype)
# Run all three implementations
logits_torch = logits.clone()
logits_cuda = logits.clone()
apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_torch(
logits_torch, prompt_mask, output_mask, repetition_penalties
)
apply_repetition_penalties_cuda(
logits_cuda, prompt_mask, output_mask, repetition_penalties
)
# Compare all outputs to reference
torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
opcheck(
torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties),
)
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test for checking CUDA kernel")
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test for checking CUDA kernel"
)
@torch.inference_mode()
def test_apply_repetition_penalties_zero_seqs() -> None:
"""
@@ -104,22 +108,24 @@ def test_apply_repetition_penalties_zero_seqs() -> None:
# No tokens to mark as repeated since num_seqs=0
# Create repetition penalties tensor
repetition_penalties = torch.full((num_seqs, ),
repetition_penalty,
dtype=dtype)
repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype)
# Run all three implementations
logits_torch = logits.clone()
logits_cuda = logits.clone()
apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_torch(
logits_torch, prompt_mask, output_mask, repetition_penalties
)
apply_repetition_penalties_cuda(
logits_cuda, prompt_mask, output_mask, repetition_penalties
)
# Compare all outputs to reference
torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
opcheck(
torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties),
)

View File

@@ -9,11 +9,13 @@ import pytest
import torch
from packaging import version
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm.v1.attention.backends.flex_attention import (
FlexAttentionMetadataBuilder)
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
from ..models.utils import check_embeddings_close, check_logprobs_close
@@ -57,26 +59,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
set_seed(seed)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_flex:
with vllm_runner(
model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
) as llm_flex:
output_flex = llm_flex.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)
prompts, max_tokens, num_logprobs
)
# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
set_seed(seed)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
gpu_memory_utilization=0.85) as llm_default:
with vllm_runner(
model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
gpu_memory_utilization=0.85,
) as llm_default:
output_default = llm_default.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)
prompts, max_tokens, num_logprobs
)
check_logprobs_close(
outputs_0_lst=output_flex,
@@ -107,23 +115,27 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_flex:
with vllm_runner(
model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True,
) as llm_flex:
flex_outputs = llm_flex.embed(prompts)
# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_default:
with vllm_runner(
model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True,
) as llm_default:
default_outputs = llm_default.embed(prompts)
check_embeddings_close(
@@ -147,27 +159,29 @@ def test_block_mask_direct_vs_slow_path():
"""
device = torch.device("cuda")
vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
block_size=16,
max_model_len=1024)
vllm_config = create_vllm_config(
model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# Use a mixed batch that will create groups spanning multiple sequences
batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
query_lens=[33, 5, 32, 64],
name="test_mixed_batch")
batch_spec = BatchSpec(
seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch"
)
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)
batch_spec, vllm_config.cache_config.block_size, device
)
builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
device)
builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)
metadata_direct = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
metadata_direct = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
builder.direct_build = False
metadata_slow = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
metadata_slow = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
assert metadata_direct.block_mask is not None
assert metadata_slow.block_mask is not None
@@ -184,20 +198,20 @@ def test_block_mask_direct_vs_slow_path():
missing_details = []
for group_idx in range(num_groups):
direct_blocks = set(
direct_indices[group_idx, :direct_num[group_idx]].tolist())
slow_blocks = set(
slow_indices[group_idx, :slow_num[group_idx]].tolist())
direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist())
slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist())
missing_blocks = slow_blocks - direct_blocks
if missing_blocks:
all_contained = False
missing_details.append(
f"Group {group_idx}: missing {sorted(missing_blocks)}")
f"Group {group_idx}: missing {sorted(missing_blocks)}"
)
assert all_contained, (
"Direct path is missing blocks required by slow path:\n" +
"\n".join(missing_details))
"Direct path is missing blocks required by slow path:\n"
+ "\n".join(missing_details)
)
if __name__ == "__main__":

View File

@@ -13,13 +13,12 @@ QUANT_DTYPES = [current_platform.fp8_dtype()]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
def ref_impl(
silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
silu_and_mul_out = silu_and_mul.forward_native(x)
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
return out
@@ -27,9 +26,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 2)
out = torch.empty(out_shape,
dtype=current_platform.fp8_dtype(),
device=x.device)
out = torch.empty(out_shape, dtype=current_platform.fp8_dtype(), device=x.device)
torch.ops._C.silu_and_mul_quant(out, x, scale)
return out
@@ -57,7 +54,7 @@ def test_silu_and_mul(
layer = SiluAndMul()
# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))
scale = torch.randn((1), device=device, dtype=torch.float32)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
ref_out = ref_impl(layer, x, scale)
@@ -66,6 +63,7 @@ def test_silu_and_mul(
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
assert ref_out.shape == ops_out.shape
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
assert torch.allclose(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))

Some files were not shown because too many files have changed in this diff Show More