diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 81533c29d..40108e490 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -214,3 +214,4 @@ configuration. | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/tests/kernels/attention/test_xpu_mla_sparse.py b/tests/kernels/attention/test_xpu_mla_sparse.py new file mode 100644 index 000000000..419644923 --- /dev/null +++ b/tests/kernels/attention/test_xpu_mla_sparse.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface + + +# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L7 +def _merge_two_lse( + lse0: torch.Tensor, lse1: torch.Tensor | None, s_q: int, h_q: int +) -> torch.Tensor: + if lse1 is None: + return lse0 + else: + return torch.logsumexp( + torch.stack([lse0.view(s_q, h_q), lse1.broadcast_to(s_q, h_q)], dim=0), + dim=0, + ) + + +# Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/tests/ref.py#L19 +def reference_mla_sparse_prefill( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int, + topk_length: torch.Tensor | None = None, + attn_sink: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + - o: [s_q, h_q, dv] + - o_fp32: [s_q, h_q, dv] + - max_logits: [s_q, h_q] + - lse: [s_q, h_q] + """ + s_q, h_q, d_qk = q.shape + s_kv, _, _ = kv.shape + _, _, topk = indices.shape + + indices = indices.clone().squeeze(1) + if topk_length is not None: + mask = torch.arange(topk, device=topk_length.device).unsqueeze(0).broadcast_to( + s_q, topk + ) >= topk_length.unsqueeze(1) # [s_q, topk] + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= s_kv) # [s_q, topk] + indices[invalid_mask] = 0 + + q = q.float() + gathered_kv = ( + kv.index_select(dim=0, index=indices.flatten()).reshape(s_q, topk, d_qk).float() + ) # [s_q, topk, d_qk] + P = q @ gathered_kv.transpose(1, 2) # [s_q, h_q, topk] + P *= sm_scale + P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf") + + orig_lse = torch.logsumexp(P, dim=-1) # [s_q, h_q] + max_logits = P.max(dim=-1).values # [s_q, h_q] + + lse_for_o = _merge_two_lse(orig_lse, attn_sink, s_q, h_q) + if not torch.is_inference_mode_enabled(): + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float( + "+inf" + ) # So that corresponding O will be 0 + s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1)) + out = s_for_o @ gathered_kv[..., :d_v] # [s_q, h_q, dv] + + lonely_q_mask = orig_lse == float("-inf") # [s_q, h_q] + orig_lse[lonely_q_mask] = float("+inf") + return (out.to(kv.dtype), out, max_logits, orig_lse) + + +@pytest.mark.parametrize("device_str", ["xpu"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif( + not torch.xpu.is_available(), + reason="XPU is required", +) +def test_bf16_triton_sparse_mla(device_str, dtype): + device = torch.device(device_str) + s_q = 1 + s_kv = 256 + h_q = 64 # kernel 
expects multiple of 64 + h_kv = 1 + d_qk = 576 + d_v = 512 + topk = 128 + + torch.random.manual_seed(1234) + + q = torch.randn((s_q, h_q, d_qk), dtype=dtype, device=device) + kv = torch.randn((s_kv, h_kv, d_qk), dtype=dtype, device=device) + indices = torch.full((s_q, h_kv, topk), -1, dtype=torch.int32, device=device) + for t in range(s_q): + for h in range(h_kv): + i_i = torch.randperm(max(1, t))[:topk] + indices[t, h, : len(i_i)] = i_i + + sm_scale = d_qk**-0.5 + + out, max_logits, lse = triton_bf16_mla_sparse_interface( + q, kv, indices, sm_scale, d_v + ) + assert out.shape == (s_q, h_q, d_v) + assert max_logits.shape == (s_q, h_q) + assert lse.shape == (s_q, h_q) + + ref_out, ref_out_fp32, ref_max_logits, ref_lse = reference_mla_sparse_prefill( + q, kv, indices, sm_scale, d_v + ) + assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2) + assert torch.allclose(max_logits, ref_max_logits, atol=1e-3, rtol=1e-3) + assert torch.allclose(lse, ref_lse, atol=1e-3, rtol=1e-3) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 1f64aacd4..b873bfa7f 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -7,6 +7,7 @@ import torch from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -157,3 +158,247 @@ class xpu_ops: "get_scheduler_metadata is not implemented for xpu_ops, returning None." ) return None + + @staticmethod + def indexer_k_quant_and_cache( + k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + ) -> None: + head_dim = k.shape[-1] + k = k.view(-1, head_dim) # [total_tokens, head_dim] + + def group_quant_torch( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype | None = None, + column_major_scales: bool = False, + out_q: torch.Tensor | None = None, + use_ue8m0: bool | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if use_ue8m0 is None: + # Default fallback - could import is_deep_gemm_e8m0_used if needed + use_ue8m0 = False + + if dtype is None: + dtype = current_platform.fp8_dtype() + + # Validate inputs + assert x.shape[-1] % group_size == 0, ( + f"Last dimension {x.shape[-1]} must be divisible by " + f"group_size {group_size}" + ) + assert x.stride(-1) == 1, "Input tensor groups must be contiguous" + + # Prepare output tensor + if out_q is None: + x_q = torch.empty_like(x, dtype=dtype) + else: + assert out_q.shape == x.shape + x_q = out_q + + # Reshape input for group processing + # Original shape: (..., last_dim) + # Target shape: (..., num_groups, group_size) + original_shape = x.shape + num_groups = original_shape[-1] // group_size + + # Reshape to separate groups + group_shape = original_shape[:-1] + (num_groups, group_size) + x_grouped = x.view(group_shape) + + # Compute per-group absolute maximum values + # Shape: (..., num_groups) + abs_max = torch.amax(torch.abs(x_grouped), dim=-1, keepdim=False) + abs_max = torch.maximum( + abs_max, torch.tensor(eps, device=x.device, dtype=x.dtype) + ) + + # Compute scales + FP8_MAX = torch.finfo(dtype).max + FP8_MIN = torch.finfo(dtype).min + scale_raw = abs_max / FP8_MAX + + if use_ue8m0: + # For UE8M0 format, scales must be powers of 2 + scales = torch.pow(2.0, torch.ceil(torch.log2(scale_raw))) + else: + scales = scale_raw + + # Expand scales for broadcasting with grouped data + # Shape: (..., num_groups, 1) + scales_expanded = scales.unsqueeze(-1) + + # Quantize the grouped data + x_scaled = 
x_grouped / scales_expanded + x_clamped = torch.clamp(x_scaled, FP8_MIN, FP8_MAX) + x_quantized = x_clamped.to(dtype) + + # Reshape back to original shape + x_q.copy_(x_quantized.view(original_shape)) + + # Prepare scales tensor in requested format + if column_major_scales: + # Column-major: (num_groups,) + batch_dims + # Transpose the scales to put group dimension first + scales_shape = (num_groups,) + original_shape[:-1] + x_s = scales.permute(-1, *range(len(original_shape) - 1)) + x_s = x_s.contiguous().view(scales_shape) + else: + # Row-major: batch_dims + (num_groups,) + x_s = scales.contiguous() + + # Ensure scales are float32 + return x_q, x_s.float() + + k_fp8, k_scale = group_quant_torch( + k, + group_size=quant_block_size, + column_major_scales=False, + use_ue8m0=(scale_fmt == "ue8m0"), + ) + + k_fp8_bytes = k_fp8.view(-1, head_dim).view(torch.uint8) + scale_bytes = k_scale.view(torch.uint8).view(-1, 4) + k = torch.cat( + [k_fp8_bytes, scale_bytes], dim=-1 + ) # [total_tokens, head_dim + 4] + + slot_mapping = slot_mapping.flatten() + # kv_cache: [num_block, block_size, head_dim + 4] + kv_cache.view(-1, kv_cache.shape[-1]).index_copy_(0, slot_mapping, k) + + @staticmethod + def cp_gather_indexer_k_quant_cache( + kv_cache: torch.Tensor, + dst_k: torch.Tensor, + dst_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + ) -> None: + """ + Args: + kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache + Layout per block: [k_values, scale_values] + - k_values: [block_size * head_dim] + - scale_values: [block_size * head_dim * 4 / quant_block_size] + dst_k: [num_tokens, head_dim] - output tensor for K values + dst_scale: [num_tokens, head_dim / quant_block_size * 4] + - output tensor for scale values + block_table: [batch_size, num_blocks] - block table for indexing + cu_seq_lens: [batch_size + 1] - cumulative sequence lengths + """ + batch_size = block_table.size(0) + num_tokens = dst_k.size(0) + head_dim = dst_k.size(1) + cache_block_size = kv_cache.size(1) + quant_block_size = head_dim * 4 // dst_scale.size(1) + + # For each token, find which batch it belongs to using searchsorted + token_indices = torch.arange(num_tokens, device=dst_k.device) + 1 + # cu_seq_lens is [batch_size + 1], we need to find which interval each + # token belongs to + batch_indices = torch.searchsorted(cu_seq_lens, token_indices) - 1 + batch_indices = torch.clamp(batch_indices, 0, batch_size - 1) + + # Calculate the in-batch sequence index for each token + inbatch_seq_indices = token_indices - cu_seq_lens[batch_indices] + + # Find which block each token belongs to + block_indices_in_table = inbatch_seq_indices // cache_block_size + physical_block_indices = block_table[batch_indices, block_indices_in_table] + + # Calculate the offset within each block + inblock_offsets = (inbatch_seq_indices - 1) % cache_block_size + + # Calculate strides + block_stride = kv_cache.stride(0) # stride for each block + + # Flatten kv_cache for easier indexing + kv_cache_flat = kv_cache.view(-1) + + # Calculate source offset for K values for all tokens (vectorized) + src_block_offsets = physical_block_indices * block_stride + src_k_offsets = src_block_offsets + inblock_offsets * head_dim + + # Gather K values using advanced indexing + # Create indices for all elements we need to gather + k_indices = src_k_offsets.unsqueeze(1) + torch.arange( + head_dim, device=dst_k.device + ) + dst_k[:] = kv_cache_flat[k_indices] + + # Calculate source offset for scale values (vectorized) + # Scales are 
stored after all K values for each block + scale_size = head_dim * 4 // quant_block_size + src_scale_offsets = src_block_offsets + head_dim + inblock_offsets * scale_size + + # Gather scale values + scale_indices = src_scale_offsets.unsqueeze(1) + torch.arange( + scale_size, device=dst_scale.device + ) + dst_scale[:] = kv_cache_flat[scale_indices] + + @staticmethod + def top_k_per_row_prefill( + logits: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + raw_topk_indices: torch.Tensor, + num_rows: int, + stride0: int, + strdide1: int, + topk_tokens: int, + ) -> torch.Tensor: + real_topk = min(topk_tokens, logits.shape[-1]) + topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 + mask = torch.full_like( + topk_indices, False, dtype=torch.bool, device=topk_indices.device + ) + mask = mask_lo & mask_hi + topk_indices.masked_fill_(~mask, -1) + raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = ( + topk_indices + ) + + @staticmethod + def top_k_per_row_decode( + logits: torch.Tensor, + next_n: int, + seq_lens: torch.Tensor, + raw_topk_indices: torch.Tensor, + num_rows: int, + stride0: int, + stride1: int, + topk_tokens: int, + ) -> torch.Tensor: + device = logits.device + batch_size = seq_lens.size(0) + # padded query len + padded_num_tokens = batch_size * next_n + positions = ( + torch.arange(logits.shape[-1], device=device) + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = torch.arange(padded_num_tokens, device=device) // next_n + next_n_offset = torch.arange(padded_num_tokens, device=device) % next_n + index_end_pos = (seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1) + # index_end_pos: [B * N, 1] + mask = positions <= index_end_pos + # mask: [B * N, L] + logits = logits.masked_fill(~mask, float("-inf")) + topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] + # ensure we don't set indices for the top k + # that is out of range(masked already) + # this will happen if context length is shorter than K + topk_indices[topk_indices > index_end_pos] = -1 + raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = ( + topk_indices + ) diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 5383e2f11..0d55ba858 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -135,16 +135,29 @@ def sparse_attn_indexer( topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + + if current_platform.is_xpu(): + ops.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + else: + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) # Compute lengths from row spans # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32) @@ -220,16 +233,28 @@ def sparse_attn_indexer( None, ) else: - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - 
topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + if current_platform.is_xpu(): + ops.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + else: + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp): k: torch.Tensor, weights: torch.Tensor, ): - if current_platform.is_cuda(): + if current_platform.is_cuda() or current_platform.is_xpu(): return self.forward_cuda(hidden_states, q_fp8, k, weights) elif current_platform.is_rocm(): return self.forward_hip(hidden_states, q_fp8, k, weights) else: raise NotImplementedError( "SparseAttnIndexer native forward is only implemented for " - "CUDA and ROCm platform." + "CUDA, ROCm and XPU platforms." ) def forward_cuda( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 893b5454f..b7bcee4dd 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -61,7 +61,8 @@ class XPUPlatform(Platform): dtype = attn_selector_config.dtype if attn_selector_config.use_sparse: - raise NotImplementedError("Sparse Attention is not supported on XPU.") + logger.info_once("Using XPU MLA Sparse backend.") + return AttentionBackendEnum.XPU_MLA_SPARSE.get_path() if attn_selector_config.use_mla: logger.info_once("Using Triton MLA backend on V1 engine.") return AttentionBackendEnum.TRITON_MLA.get_path() diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index ce459ca91..f4866a702 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -17,4 +17,7 @@ else: tl = TritonLanguagePlaceholder() tldevice = TritonLanguagePlaceholder() -__all__ = ["HAS_TRITON", "triton", "tl", "tldevice"] +LOG2E = 1.4426950408889634 +LOGE2 = 0.6931471805599453 + +__all__ = ["HAS_TRITON", "triton", "tl", "tldevice", "LOG2E", "LOGE2"] diff --git a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py new file mode 100644 index 000000000..feb8191fd --- /dev/null +++ b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional + +import numpy as np +import torch + +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.mla_attention import ( + get_mla_dims, +) +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + CommonAttentionMetadata, + SparseMLAAttentionImpl, +) +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + triton_convert_req_index_to_global_index, +) +from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer +logger = init_logger(__name__) + + +class XPUMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: 
ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + + @staticmethod + def get_name() -> str: + return "XPU_MLA_SPARSE" + + @staticmethod + def get_metadata_cls() -> type["XPUMLASparseMetadata"]: + return XPUMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["XPUMLASparseMetadataBuilder"]: + return XPUMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["XPUMLASparseImpl"]: + return XPUMLASparseImpl + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + +@dataclass +class XPUMLASparseMetadata(AttentionMetadata): + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + + block_size: int = 1 + topk_tokens: int = 2048 + + +@dataclass +class XPUMLASparseMetadataBuilder(AttentionMetadataBuilder[XPUMLASparseMetadata]): + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 + ) + self.max_model_len_tensor = torch.tensor( + [self.model_config.max_model_len], device=device, dtype=torch.int32 + ) + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty( + (1, 1), dtype=torch.int32, device=self.device + ) + + self.req_id_per_token_buffer = torch.empty( + (max_num_batched_tokens,), + dtype=torch.int32, + device=device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> XPUMLASparseMetadata: + num_tokens = common_attn_metadata.num_actual_tokens + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + metadata = XPUMLASparseMetadata( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + 
block_table=common_attn_metadata.block_table_tensor, + req_id_per_token=req_id_per_token, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + ) + return metadata + + +class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + topk_indice_buffer: torch.Tensor | None = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.kv_lora_rank: int = mla_args["kv_lora_rank"] + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer + + def _forward_bf16_kv( + self, + q: torch.Tensor, # [sq, heads, d_qk] + kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] + topk_indices: torch.Tensor, # [sq, topk] + attn_metadata: XPUMLASparseMetadata, + ) -> torch.Tensor: + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) + + topk_indices = topk_indices.view(num_tokens, 1, -1) + + output, _, _ = triton_bf16_mla_sparse_interface( + q, + kv_c_and_k_pe_cache, + topk_indices, + sm_scale=self.softmax_scale, + ) + + return output[:, : self.num_heads, :] + + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: XPUMLASparseMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet") + + # Concatenate q if it's a tuple (ql_nope, q_pe) + if isinstance(q, tuple): + q = torch.cat(q, dim=-1) + + num_actual_toks = q.shape[0] + + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + + attn_out = self._forward_bf16_kv( + q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata + ) + + return attn_out, None diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index 8e60551e2..4744ead4f 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): ROCM_AITER_MLA_SPARSE = ( "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend" ) + XPU_MLA_SPARSE = "vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend" TORCH_SDPA = "" # this tag is only used for ViT FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER_MLA = ( diff --git a/vllm/v1/attention/ops/xpu_mla_sparse.py b/vllm/v1/attention/ops/xpu_mla_sparse.py new file mode 100644 index 000000000..8a4c1ffd6 --- /dev/null +++ 
b/vllm/v1/attention/ops/xpu_mla_sparse.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import LOG2E, LOGE2, tl, triton + + +@triton.jit +def _bf16_mla_sparse_kernel( + q_buffer, + k_buffer, + v_buffer, + indices_ptr, + out_ptr, + softmax_lse_ptr, + max_logits_ptr, + seq_q, + seq_kv, + h_q, + dim_qk, + dim_v, + stride_q_token, + stride_q_head, + stride_k_token, + stride_k_head, + stride_v_token, + stride_v_head, + stride_out_token, + stride_out_head, + stride_lse, + stride_indices_token, + stride_indices_head, + sm_scale, + kv_group_num: tl.constexpr, + index_topk: tl.constexpr, + BLOCK_H: tl.constexpr, # block size for num heads + BLOCK_M: tl.constexpr, # block size for num tokens + BLOCK_N: tl.constexpr, # block size for indices + BLOCK_DV: tl.constexpr, # block size for dim_v + BLOCK_DMODEL: tl.constexpr, # block size for dim_nope + BLOCK_DPE: tl.constexpr, # block size for positional embedding + LOGE2: tl.constexpr, +): + cur_q = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head_id = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + + VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < h_q) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + + off_q = cur_q * stride_q_token + cur_head[:, None] * stride_q_head + offs_d[None, :] + mask_dmodel = offs_d < BLOCK_DMODEL + q = tl.load( + q_buffer + off_q, mask=(mask_h[:, None]) & (mask_dmodel[None, :]), other=0.0 + ) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + off_qpe = ( + cur_q * stride_q_token + + cur_head[:, None] * stride_q_head + + offs_dpe[None, :] + ) + # assume dim_qk == BLOCK_DMODEL + BLOCK_DPE + mask_dpe = offs_dpe < dim_qk + qpe = tl.load( + q_buffer + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + for start_indice in range(0, index_topk, BLOCK_N): + offs_indice = start_indice + tl.arange(0, BLOCK_N) + mask_indice = offs_indice < index_topk + indices = tl.load( + indices_ptr + + ( + cur_q * stride_indices_token + + cur_kv_head_id * stride_indices_head + + offs_indice + ), + mask=mask_indice, + other=-1, + ) + + mask_kv = (indices >= 0) & (indices < seq_kv) + mask_kv_d = mask_dmodel + offs_k = ( + indices[None, :] * stride_k_token + + cur_kv_head_id * stride_k_head + + offs_d[:, None] + ) + + # q_nope @ k_nope + k = tl.load( + k_buffer + offs_k, mask=(mask_kv[None, :]) & (mask_kv_d[:, None]), other=0.0 + ) + qk = tl.dot(q, k.to(q.dtype)) + + if BLOCK_DPE > 0: + # q_rope @ k_rope + offs_kpe = ( + indices[None, :] * stride_k_token + + cur_kv_head_id * stride_k_head + + offs_dpe[:, None] + ) + mask_k_dpe = offs_dpe < dim_qk + kpe = tl.load( + k_buffer + offs_kpe, + mask=(mask_kv[None, :]) & (mask_k_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(q.dtype)) + + # apply scaling + qk *= sm_scale + qk = tl.where((mask_h[:, None]) & (mask_kv[None, :]), qk, -float("inf")) + + # load v + mask_v_d = offs_dv < dim_v + offs_v = ( + indices[:, None] * stride_v_token + + cur_kv_head_id * stride_v_head + + offs_dv[None, :] + ) + v = tl.load( + v_buffer + offs_v, mask=(mask_kv[:, 
None]) & (mask_v_d[None, :]), other=0.0 + ) + + # online softmax + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + + # score @ v + acc += tl.dot(p.to(v.dtype), v) + + # update global sum and max + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + # rescaling + acc /= e_sum[:, None] + + max_logits = e_max * LOGE2 + # calculate lse + lse = max_logits + tl.log2(e_sum) * LOGE2 + + # write output + offs_o = ( + cur_q * stride_out_token + + cur_head[:, None] * stride_out_head + + offs_dv[None, :] + ) + mask_out_d = offs_dv < dim_v + tl.store( + out_ptr + offs_o, + acc.to(tl.bfloat16), + mask=(mask_h[:, None]) & (mask_out_d[None, :]), + ) + + offs_lse = cur_q * stride_lse + cur_head + tl.store(softmax_lse_ptr + offs_lse, lse, mask=mask_h) + tl.store(max_logits_ptr + offs_lse, max_logits, mask=mask_h) + + +# reference implementation of bf16 sparse prefill kernel +def triton_bf16_mla_sparse_interface( + q: torch.Tensor, # [num_tokens, num_heads_q, dim_qk] + kv: torch.Tensor, # [num_tokens, num_heads_kv, dim_qk] + indices: torch.Tensor, # [num_tokens, num_heads_kv, topk] + sm_scale: float, + d_v: int = 512, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + out : [num_tokens, num_heads_q, d_v] + max_logits : [num_tokens, num_heads_q] + lse : logsumexp, [num_tokens, num_heads_q] + """ + num_tokens, num_heads_q, dim_qk = q.shape + _, num_heads_kv, _ = kv.shape + assert dim_qk == kv.shape[2], "q and kv have different head dimensions" + + # for deepseek v3.2, index topk should be 2048 + _, _, index_topk = indices.shape + + BLOCK_H = 16 + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_DV = 512 + assert d_v == BLOCK_DV, "only support d_v = 512" + + assert dim_qk == BLOCK_DMODEL + BLOCK_DPE, ( + "dim_qk does not match BLOCK_DMODEL + BLOCK_DPE" + ) + assert num_heads_kv == 1, "only support kv head = 1 for now" + assert index_topk % BLOCK_N == 0, "index_topk must be multiple of BLOCK_N" + + sm_scale *= LOG2E + + kv_group_num = num_heads_q // num_heads_kv + grid = ( + num_tokens, + triton.cdiv(num_heads_q, min(BLOCK_H, kv_group_num)), + ) + + out = torch.zeros((num_tokens, num_heads_q, d_v), dtype=q.dtype, device=q.device) + softmax_lse = torch.zeros( + (num_tokens, num_heads_q), dtype=torch.float32, device=q.device + ) + max_logits = torch.zeros( + (num_tokens, num_heads_q), dtype=torch.float32, device=q.device + ) + + k = kv + v = kv[..., :d_v] + + _bf16_mla_sparse_kernel[grid]( + q_buffer=q, + k_buffer=k, + v_buffer=v, + indices_ptr=indices, + out_ptr=out, + softmax_lse_ptr=softmax_lse, + max_logits_ptr=max_logits, + seq_q=num_tokens, + seq_kv=kv.shape[0], + h_q=num_heads_q, + dim_qk=dim_qk, + dim_v=d_v, + stride_q_token=q.stride(0), + stride_q_head=q.stride(1), + stride_k_token=k.stride(0), + stride_k_head=k.stride(1), + stride_v_token=v.stride(0), + stride_v_head=v.stride(1), + stride_out_token=out.stride(0), + stride_out_head=out.stride(1), + stride_lse=softmax_lse.stride(0), + stride_indices_token=indices.stride(0), + stride_indices_head=indices.stride(1), + sm_scale=sm_scale, + kv_group_num=kv_group_num, + index_topk=index_topk, + BLOCK_H=BLOCK_H, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DV=BLOCK_DV, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + LOGE2=LOGE2, + ) + + return out, max_logits, softmax_lse
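
A minimal sketch of how the new triton_bf16_mla_sparse_interface op can be exercised on its own, assuming an XPU-enabled PyTorch build; the shapes mirror the DeepSeek-V3.2 MLA layout used by the new unit test (576 = 512 latent + 64 RoPE dims, d_v = 512, a single KV head, and a top-k that is a multiple of 16). The concrete sizes below are illustrative, not requirements of the patch:

# Illustrative only; mirrors tests/kernels/attention/test_xpu_mla_sparse.py.
import torch

from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface

device = torch.device("xpu")           # requires an XPU-enabled PyTorch build
s_q, s_kv, h_q, h_kv = 4, 256, 64, 1   # single KV head, as the kernel asserts
d_qk, d_v, topk = 576, 512, 128        # 512 "nope" dims + 64 RoPE dims; topk % 16 == 0

q = torch.randn((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
kv = torch.randn((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
# Per-query-token top-k indices into the KV sequence; negative or
# out-of-range entries are masked out inside the kernel.
indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device=device)

out, max_logits, lse = triton_bf16_mla_sparse_interface(
    q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v
)
assert out.shape == (s_q, h_q, d_v)
assert max_logits.shape == (s_q, h_q) and lse.shape == (s_q, h_q)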