[Bugfix][ROCm]Fix Qwen3-Next-80B-A3B-Thinking inference and optimize non-standard block size (544) support under rocm_atten (#31380)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -112,6 +112,7 @@ def test_contexted_kv_attention(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
block_size: int = 32,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
@@ -138,7 +139,6 @@ def test_contexted_kv_attention(
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
# ensure one sequence in batch is a decode
|
||||
@@ -333,6 +333,7 @@ def test_contexted_kv_attention_alibi(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
op: Callable,
|
||||
block_size: int = 32,
|
||||
) -> None:
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
@@ -385,7 +386,6 @@ def test_contexted_kv_attention_alibi(
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
@@ -637,3 +637,34 @@ def test_contexted_kv_attention_alibi_f32(
|
||||
test_contexted_kv_attention_alibi(
|
||||
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("op", OPS)
|
||||
@torch.inference_mode()
|
||||
def test_qwen3_nonstandard_block_size(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
A separate test function specifically added
|
||||
for Qwen3-Next-80B (Block Size 544).
|
||||
"""
|
||||
if not current_platform.is_rocm():
|
||||
pytest.skip("544 block size optimization is only for ROCm.")
|
||||
|
||||
test_contexted_kv_attention(
|
||||
num_heads=64,
|
||||
num_queries_per_kv=1,
|
||||
head_size=head_size,
|
||||
block_size=544,
|
||||
sliding_window=0,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype="auto",
|
||||
device=device,
|
||||
op=op,
|
||||
)
|
||||
|
||||
@@ -46,6 +46,7 @@ def kernel_paged_attention_2d(
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
PHYSICAL_BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
@@ -104,14 +105,15 @@ def kernel_paged_attention_2d(
|
||||
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
|
||||
L = tl.zeros([num_queries_per_kv_padded], dtype=tl.float32)
|
||||
else:
|
||||
M = tl.load(
|
||||
sink_ptr + query_head_idx,
|
||||
mask=head_mask,
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
L = tl.where(float("-inf") < M, 1.0, 0.0)
|
||||
|
||||
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
@@ -125,30 +127,45 @@ def kernel_paged_attention_2d(
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
v_offset = (
|
||||
physical_block_idx * stride_v_cache_0
|
||||
+ kv_head_idx * stride_v_cache_1
|
||||
+ offs_d[None, :] * stride_v_cache_2
|
||||
+ offs_n[:, None] * stride_v_cache_3
|
||||
)
|
||||
start_n = j * BLOCK_SIZE
|
||||
# Calculate the logical location within a non-standard physical block,
|
||||
# such as 544 in Qwen/Qwen3-Next-80B-A3B-Thinking.
|
||||
# Supports non-contiguous mapping
|
||||
# from logical blocks to physical blocks
|
||||
abs_token_idx = start_n + offs_n
|
||||
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE
|
||||
# Vectorized loading of physical block IDs
|
||||
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
|
||||
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# 5D addressing logic of K
|
||||
k_offset = (
|
||||
physical_block_idx * stride_k_cache_0
|
||||
p_block_idx[None, :] * stride_k_cache_0
|
||||
+ kv_head_idx * stride_k_cache_1
|
||||
+ (offs_d[:, None] // x) * stride_k_cache_2
|
||||
+ offs_n[None, :] * stride_k_cache_3
|
||||
+ internal_offsets[None, :] * stride_k_cache_3
|
||||
+ (offs_d[:, None] % x) * stride_k_cache_4
|
||||
)
|
||||
|
||||
# 4D addressing logic of V (Slot is innermost)
|
||||
v_offset = (
|
||||
p_block_idx[:, None] * stride_v_cache_0
|
||||
+ kv_head_idx * stride_v_cache_1
|
||||
+ offs_d[None, :] * stride_v_cache_2
|
||||
+ internal_offsets[:, None] * stride_v_cache_3
|
||||
)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0)
|
||||
K_load = tl.load(
|
||||
key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
@@ -156,7 +173,12 @@ def kernel_paged_attention_2d(
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0)
|
||||
V_load = tl.load(
|
||||
value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
@@ -167,9 +189,9 @@ def kernel_paged_attention_2d(
|
||||
boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32)
|
||||
seq_mask = seq_offset[None, :] < boundary
|
||||
|
||||
# S : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
# First calculate the dot, then apply the mask.
|
||||
qk = scale * tl.dot(Q, K)
|
||||
S = tl.where(head_mask[:, None] & seq_mask, qk, float("-inf"))
|
||||
|
||||
context_len = seq_len - 1
|
||||
|
||||
@@ -184,13 +206,15 @@ def kernel_paged_attention_2d(
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# P : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
p = tl.exp(S - m_j[:, None])
|
||||
p = tl.where(m_j[:, None] == float("-inf"), 0.0, p)
|
||||
|
||||
# l_j : (num_queries_per_kv,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
l_j = tl.sum(p, axis=1)
|
||||
|
||||
# alpha : (num_queries_per_kv, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
alpha = tl.where(float("-inf") == M, 0.0, alpha)
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc = acc * alpha[:, None]
|
||||
@@ -200,10 +224,10 @@ def kernel_paged_attention_2d(
|
||||
M = m_j
|
||||
|
||||
# acc : (num_queries_per_kv, BLOCK_SIZE,)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
acc += tl.dot(p.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
acc = acc / (L[:, None] + 1e-10)
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
@@ -241,9 +265,10 @@ def chunked_prefill_paged_decode(
|
||||
output_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
is_block_table_ptr: bool = False,
|
||||
):
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (query.shape[1] ** 0.5)
|
||||
sm_scale = 1.0 / (query.shape[2] ** 0.5)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
|
||||
@@ -315,6 +340,16 @@ def chunked_prefill_paged_decode(
|
||||
alibi_slopes,
|
||||
sinks,
|
||||
)
|
||||
# Triton is only forced when encountering a non-standard block
|
||||
# like Qwen3 with a size of 544.
|
||||
# 1. Check if block_size is a power of 2 (16, 32, 64...)
|
||||
# 2. If it's a power of 2, we trust the vLLM's native use_custom decision.
|
||||
# 3. If it's not a power of 2 (such as Qwen3's 544),
|
||||
# then our Triton path is forced.
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
if not is_pow2:
|
||||
use_custom = False
|
||||
|
||||
if use_custom:
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
max_num_partitions = (
|
||||
@@ -356,6 +391,25 @@ def chunked_prefill_paged_decode(
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
real_block_size = value_cache.shape[3]
|
||||
# The standard model directly uses the original block_size.
|
||||
# Non-standard 544 uses 32 to accommodate integer division logic.
|
||||
TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
|
||||
if is_block_table_ptr:
|
||||
# Using the physical base address of tensors
|
||||
kv_element_size = key_cache.element_size()
|
||||
block_byte_stride = key_cache.stride(0) * kv_element_size
|
||||
# Get the starting physical address of the KV Cache
|
||||
base_addr = key_cache.data_ptr()
|
||||
|
||||
# Normalization: Directly calculate the block offset
|
||||
# of the pointer relative to the base address
|
||||
processed_block_table = ((block_table - base_addr) // block_byte_stride).to(
|
||||
torch.int32
|
||||
)
|
||||
else:
|
||||
processed_block_table = block_table.to(torch.int32)
|
||||
|
||||
kernel_paged_attention_2d[
|
||||
(
|
||||
num_seqs,
|
||||
@@ -367,7 +421,7 @@ def chunked_prefill_paged_decode(
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
sink_ptr=sinks,
|
||||
block_tables_ptr=block_table,
|
||||
block_tables_ptr=processed_block_table,
|
||||
seq_lens_ptr=seq_lens,
|
||||
alibi_slopes_ptr=alibi_slopes,
|
||||
scale=sm_scale,
|
||||
@@ -377,12 +431,13 @@ def chunked_prefill_paged_decode(
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||
block_table_stride=block_table.stride(0),
|
||||
block_table_stride=processed_block_table.stride(0),
|
||||
query_stride_0=query.stride(0),
|
||||
query_stride_1=query.stride(1),
|
||||
output_stride_0=output.stride(0),
|
||||
output_stride_1=output.stride(1),
|
||||
BLOCK_SIZE=block_size,
|
||||
BLOCK_SIZE=TRITON_BLOCK_SIZE,
|
||||
PHYSICAL_BLOCK_SIZE=real_block_size,
|
||||
HEAD_SIZE=head_size,
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
USE_ALIBI_SLOPES=use_alibi_slopes,
|
||||
|
||||
@@ -79,6 +79,7 @@ def _fwd_kernel(
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DMODEL_PADDED: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PHYSICAL_BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SLIDING_WINDOW: tl.constexpr,
|
||||
num_unroll_cache: tl.constexpr,
|
||||
@@ -139,42 +140,52 @@ def _fwd_kernel(
|
||||
# initialize pointer to m and l
|
||||
if not USE_SINKS:
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
else:
|
||||
m_i = tl.load(
|
||||
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
|
||||
mask=(offs_m < cur_batch_query_len),
|
||||
other=float("-inf"),
|
||||
).to(dtype=tl.float32)
|
||||
l_i = tl.where(m_i > float("-inf"), 1.0, 0.0)
|
||||
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
|
||||
|
||||
# compute query against context (no causal mask here)
|
||||
for start_n in tl.range(
|
||||
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
|
||||
):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
|
||||
# -- compute qk ----
|
||||
# Under a block size of 544 (Qwen/Qwen3-Next-80B-A3B-Thinking),
|
||||
# replace one physical block every 17 32-Tile blocks
|
||||
# Calculate the logical block index of each of the 32 tokens
|
||||
# in the current Tile (handling cross-block cases).
|
||||
token_indices = start_n + offs_bs_n
|
||||
bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# 2. Vectorized loading of physical block IDs from B_Loc
|
||||
bn = tl.load(
|
||||
B_Loc
|
||||
+ cur_batch * stride_b_loc_b
|
||||
+ (start_n // BLOCK_SIZE) * stride_b_loc_s
|
||||
B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s
|
||||
).to(tl.int64)
|
||||
# [D,BLOCK_SIZE]
|
||||
|
||||
# 3. Calculate the exact offset of
|
||||
# each token within its physical block.
|
||||
internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
|
||||
|
||||
# Addressing of K (5D)
|
||||
off_k = (
|
||||
bn[None, :] * stride_k_cache_bs
|
||||
+ cur_kv_head * stride_k_cache_h
|
||||
+ (offs_d[:, None] // x) * stride_k_cache_d
|
||||
+ ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
|
||||
+ internal_offsets[None, :] * stride_k_cache_bl
|
||||
+ (offs_d[:, None] % x) * stride_k_cache_x
|
||||
)
|
||||
|
||||
# [BLOCK_SIZE,D]
|
||||
# Addressing of V (4D)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs
|
||||
+ cur_kv_head * stride_v_cache_h
|
||||
+ offs_d[None, :] * stride_v_cache_d
|
||||
+ offs_bs_n[:, None] * stride_v_cache_bl
|
||||
+ internal_offsets[:, None] * stride_v_cache_bl
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -195,12 +206,12 @@ def _fwd_kernel(
|
||||
else:
|
||||
k = k_load
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
|
||||
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
|
||||
# qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
|
||||
qk = sm_scale * tl.dot(q, k, input_precision=IN_PRECISION)
|
||||
qk = tl.where(
|
||||
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
|
||||
)
|
||||
qk *= sm_scale
|
||||
# qk *= sm_scale
|
||||
if SLIDING_WINDOW > 0:
|
||||
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
|
||||
# Q entries in sequence
|
||||
@@ -217,14 +228,16 @@ def _fwd_kernel(
|
||||
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
|
||||
< SLIDING_WINDOW,
|
||||
qk,
|
||||
-10000,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
@@ -293,14 +306,17 @@ def _fwd_kernel(
|
||||
qk = tl.where(
|
||||
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
|
||||
qk,
|
||||
-10000,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
# compute running maximum
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
p = tl.where(m_ij[:, None] == float("-inf"), 0.0, p)
|
||||
l_ij = tl.sum(p, axis=1)
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
# To prevent NaN from appearing in the first round
|
||||
alpha = tl.where(m_i == float("-inf"), 0.0, alpha)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update acc
|
||||
@@ -317,7 +333,7 @@ def _fwd_kernel(
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
acc = acc / (l_i[:, None] + 1e-10)
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
@@ -637,6 +653,7 @@ def context_attention_fwd(
|
||||
skip_decode=False,
|
||||
fp8_out_scale=None,
|
||||
sinks=None,
|
||||
is_block_table_ptr: bool = False,
|
||||
):
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
|
||||
@@ -689,6 +706,19 @@ def context_attention_fwd(
|
||||
if sliding_window is None or sliding_window <= 0:
|
||||
sliding_window = 0
|
||||
|
||||
if is_block_table_ptr:
|
||||
kv_element_size = k_cache.element_size()
|
||||
block_byte_stride = k_cache.stride(0) * kv_element_size
|
||||
# The physical starting point of the obtained KV Cache Pool
|
||||
base_addr = k_cache.data_ptr()
|
||||
|
||||
mask = b_loc > 0
|
||||
processed_b_loc = torch.where(
|
||||
mask, (b_loc - base_addr) // block_byte_stride, b_loc
|
||||
).to(torch.int32)
|
||||
else:
|
||||
processed_b_loc = b_loc.to(torch.int32)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
assert sinks is None, "Sinks arg is not supported with alibi"
|
||||
assert fp8_out_scale is None, "FP8 output not supported with alibi"
|
||||
@@ -752,7 +782,24 @@ def context_attention_fwd(
|
||||
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
||||
extra_kargs = {}
|
||||
if current_platform.is_rocm():
|
||||
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
|
||||
extra_kargs = {}
|
||||
|
||||
real_block_size = v_cache.shape[3]
|
||||
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
|
||||
# For standard models involving powers of 2,
|
||||
# follow the original logic (Llama 128/64)
|
||||
# For non-standard models (Qwen3-next block_size 544), set to 32.
|
||||
if is_pow2:
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
else:
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 32
|
||||
|
||||
# TRITON_BLOCK_SIZE is kept at 32 to ensure
|
||||
# correct alignment logic when the kernel handles
|
||||
# non-standard sizes (such as 544).
|
||||
TRITON_BLOCK_SIZE = 32
|
||||
|
||||
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid_fn](
|
||||
@@ -762,7 +809,7 @@ def context_attention_fwd(
|
||||
k_cache,
|
||||
v_cache,
|
||||
sinks,
|
||||
b_loc,
|
||||
processed_b_loc,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
@@ -771,8 +818,8 @@ def context_attention_fwd(
|
||||
b_seq_len,
|
||||
k_cache.shape[4],
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
processed_b_loc.stride(0),
|
||||
processed_b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
@@ -785,16 +832,17 @@ def context_attention_fwd(
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size]
|
||||
BLOCK_SIZE=v_cache.shape[3],
|
||||
stride_k_cache_bs=k_cache.stride(0),
|
||||
stride_k_cache_h=k_cache.stride(1),
|
||||
stride_k_cache_d=k_cache.stride(2),
|
||||
stride_k_cache_bl=k_cache.stride(3),
|
||||
stride_k_cache_x=k_cache.stride(4),
|
||||
stride_v_cache_bs=v_cache.stride(0),
|
||||
stride_v_cache_h=v_cache.stride(1),
|
||||
stride_v_cache_d=v_cache.stride(2),
|
||||
stride_v_cache_bl=v_cache.stride(3),
|
||||
BLOCK_SIZE=TRITON_BLOCK_SIZE,
|
||||
PHYSICAL_BLOCK_SIZE=real_block_size,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
IN_PRECISION=IN_PRECISION,
|
||||
BLOCK_DMODEL=Lk,
|
||||
@@ -802,8 +850,8 @@ def context_attention_fwd(
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
SKIP_DECODE=skip_decode,
|
||||
USE_FP8=fp8_out_scale is not None,
|
||||
BLOCK_M=128,
|
||||
BLOCK_N=64,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
num_unroll_cache=4,
|
||||
num_unroll_request=1,
|
||||
num_warps=4,
|
||||
|
||||
@@ -20,10 +20,15 @@ def reshape_and_cache_kernel_flash(
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
head_stride: tl.int64,
|
||||
dim_stride_k: tl.int64,
|
||||
dim_stride_v: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
x: tl.constexpr,
|
||||
USE_HEAD_MAJOR_LAYOUT: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
@@ -35,17 +40,38 @@ def reshape_and_cache_kernel_flash(
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
src_key_idx = token_idx * key_stride
|
||||
src_value_idx = token_idx * value_stride
|
||||
|
||||
tgt_idx = block_idx * block_stride + block_offset * page_stride
|
||||
if USE_HEAD_MAJOR_LAYOUT:
|
||||
# Decompose the tile index back into head and dim coordinates.
|
||||
cur_head = tile_pos // head_size
|
||||
cur_dim = tile_pos % head_size
|
||||
# Value addressing (4D): [Block, Head, Dim, Slot]
|
||||
tgt_idx_v = (
|
||||
block_idx * block_stride
|
||||
+ cur_head * head_stride
|
||||
+ cur_dim * dim_stride_v
|
||||
+ block_offset * 1
|
||||
)
|
||||
# Key addressing (5D): [Block, Head, Dim//8, Slot, 8]
|
||||
tgt_idx_k = (
|
||||
block_idx * block_stride
|
||||
+ cur_head * head_stride
|
||||
+ (cur_dim // x) * dim_stride_k
|
||||
+ block_offset * x
|
||||
+ (cur_dim % x)
|
||||
)
|
||||
else:
|
||||
tgt_base = block_idx * block_stride + block_offset * page_stride
|
||||
tgt_idx_k = tgt_base + tile_pos
|
||||
tgt_idx_v = tgt_base + tile_pos
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(
|
||||
@@ -73,12 +99,12 @@ def reshape_and_cache_kernel_flash(
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
key_cache_ptr + tgt_idx + tile_pos,
|
||||
key_cache_ptr + tgt_idx_k,
|
||||
key_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
tl.store(
|
||||
value_cache_ptr + tgt_idx + tile_pos,
|
||||
value_cache_ptr + tgt_idx_v,
|
||||
value_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
@@ -99,17 +125,26 @@ def triton_reshape_and_cache_flash(
|
||||
):
|
||||
num_heads = key.shape[1]
|
||||
head_size = key.shape[2]
|
||||
block_size = key_cache.shape[1]
|
||||
n = num_heads * head_size
|
||||
|
||||
use_head_major_layout = key_cache.ndim == 5
|
||||
if use_head_major_layout:
|
||||
block_size = key_cache.shape[3]
|
||||
x = key_cache.shape[4]
|
||||
head_stride = key_cache.stride(1)
|
||||
dim_stride_k = key_cache.stride(2)
|
||||
dim_stride_v = value_cache.stride(2)
|
||||
else:
|
||||
block_size = key_cache.shape[1]
|
||||
x = 1
|
||||
dim_stride_k = 0
|
||||
dim_stride_v = 0
|
||||
head_stride = key_cache.stride()[2]
|
||||
n = num_heads * head_size
|
||||
key_stride = key.stride()[0]
|
||||
value_stride = value.stride()[0]
|
||||
block_stride = key_cache.stride()[0]
|
||||
page_stride = key_cache.stride()[1]
|
||||
|
||||
head_stride = key_cache.stride()[2]
|
||||
assert head_stride == head_size, "only continous heads are supported"
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
)
|
||||
@@ -171,10 +206,15 @@ def triton_reshape_and_cache_flash(
|
||||
key_stride=key_stride,
|
||||
value_stride=value_stride,
|
||||
block_stride=block_stride,
|
||||
head_stride=head_stride,
|
||||
dim_stride_k=dim_stride_k,
|
||||
dim_stride_v=dim_stride_v,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
x=x,
|
||||
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
|
||||
@@ -15,6 +15,9 @@ from vllm.attention.backends.abstract import (
|
||||
)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -321,16 +324,38 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
|
||||
if is_pow2:
|
||||
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
|
||||
# force using our modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
|
||||
Reference in New Issue
Block a user