Add XPU MLA Sparse backend for DeepSeek v3.2 (#33230)
Signed-off-by: Zhang, Wuxun <wuxun.zhang@intel.com>
vllm/v1/attention/backends/mla/xpu_mla_sparse.py (new file, +257 lines)
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
import torch

from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
    get_mla_dims,
)
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionLayer,
    AttentionMetadata,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer

logger = init_logger(__name__)

class XPUMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]

    @staticmethod
    def get_name() -> str:
        return "XPU_MLA_SPARSE"

    @staticmethod
    def get_metadata_cls() -> type["XPUMLASparseMetadata"]:
        return XPUMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["XPUMLASparseMetadataBuilder"]:
        return XPUMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["XPUMLASparseImpl"]:
        return XPUMLASparseImpl

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]
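
For illustration (not part of the diff), a minimal sketch of the flat cache layout this backend declares. The 576 head size is the MLA latent: kv_lora_rank (512) plus qk_rope_head_dim (64), stored as a single "head" per slot, so the cache has no separate head dimension:

    # hypothetical sizes; layout is (num_blocks, block_size, head_size)
    shape = XPUMLASparseBackend.get_kv_cache_shape(
        num_blocks=1024, block_size=64, num_kv_heads=1, head_size=576
    )
    assert shape == (1024, 64, 576)  # one 576-dim latent vector per slot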


@dataclass
class XPUMLASparseMetadata(AttentionMetadata):
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor

    block_size: int = 1
    topk_tokens: int = 2048


@dataclass
class XPUMLASparseMetadataBuilder(AttentionMetadataBuilder[XPUMLASparseMetadata]):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.topk_tokens_tensor = torch.tensor(
            [self.topk_tokens], device=device, dtype=torch.int32
        )
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len], device=device, dtype=torch.int32
        )
        # this is ignored by `flash_mla_with_kvcache` if indices is not None
        self.dummy_block_table = torch.empty(
            (1, 1), dtype=torch.int32, device=self.device
        )

        self.req_id_per_token_buffer = torch.empty(
            (max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> XPUMLASparseMetadata:
        num_tokens = common_attn_metadata.num_actual_tokens
        starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        # Map each token to the index of the request it belongs to.
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )

        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        metadata = XPUMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
        )
        return metadata
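
For illustration (not part of the diff): the req_id_per_token construction in build() is the usual cumulative-offsets-to-segment-ids expansion. A minimal sketch with made-up values:

    import numpy as np

    # query_start_loc for 3 requests with 2, 1, and 3 query tokens
    starts = np.asarray([0, 2, 3, 6], dtype=np.int32)
    seg_lengths = np.diff(starts)  # [2, 1, 3]
    req_id_per_token = np.repeat(
        np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
    )
    print(req_id_per_token)  # [0 0 1 2 2 2]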


class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: XPUMLASparseMetadata,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        topk_indices = topk_indices.view(num_tokens, 1, -1)

        output, _, _ = triton_bf16_mla_sparse_interface(
            q,
            kv_c_and_k_pe_cache,
            topk_indices,
            sm_scale=self.softmax_scale,
        )

        return output[:, : self.num_heads, :]

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: XPUMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # NOTE(lucas): the sparse FlashMLA kernels want to use the
        # MQA 576/512 approach for both prefill and decode

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "FP8 KV cache is not supported with XPU MLA Sparse yet"
            )

        # Concatenate q if it's a tuple (ql_nope, q_pe)
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        num_actual_toks = q.shape[0]

        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # Convert per-request top-k token positions into global slot indices
        # in the paged KV cache.
        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        attn_out = self._forward_bf16_kv(
            q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
        )

        return attn_out, None
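
For illustration (not part of the diff), a plain-PyTorch sketch of what triton_convert_req_index_to_global_index is assumed to compute: translating per-request token positions into global cache slots via the paged block table. This is standard paged-KV address arithmetic; the real kernel in flashmla_sparse also handles invalid (e.g. -1) indices, which this sketch omits:

    import torch

    def convert_req_index_to_global_index_ref(
        req_id_per_token: torch.Tensor,  # [num_tokens], int32
        block_table: torch.Tensor,  # [num_reqs, max_blocks], int32
        topk_indices: torch.Tensor,  # [num_tokens, topk], int32, assumed >= 0
        block_size: int,
    ) -> torch.Tensor:
        # token position -> (block number, offset in block) -> global slot
        blocks = block_table[req_id_per_token.long()]  # [num_tokens, max_blocks]
        block_ids = torch.gather(blocks, 1, (topk_indices // block_size).long())
        return block_ids * block_size + topk_indices % block_size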

@@ -57,6 +57,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
    ROCM_AITER_MLA_SPARSE = (
        "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
    )
    XPU_MLA_SPARSE = "vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend"
    TORCH_SDPA = ""  # this tag is only used for ViT
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    FLASHINFER_MLA = (

vllm/v1/attention/ops/xpu_mla_sparse.py (new file, +265 lines)
@@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import LOG2E, LOGE2, tl, triton


@triton.jit
def _bf16_mla_sparse_kernel(
    q_buffer,
    k_buffer,
    v_buffer,
    indices_ptr,
    out_ptr,
    softmax_lse_ptr,
    max_logits_ptr,
    seq_q,
    seq_kv,
    h_q,
    dim_qk,
    dim_v,
    stride_q_token,
    stride_q_head,
    stride_k_token,
    stride_k_head,
    stride_v_token,
    stride_v_head,
    stride_out_token,
    stride_out_head,
    stride_lse,
    stride_indices_token,
    stride_indices_head,
    sm_scale,
    kv_group_num: tl.constexpr,
    index_topk: tl.constexpr,
    BLOCK_H: tl.constexpr,  # block size for num heads
    BLOCK_M: tl.constexpr,  # block size for num tokens
    BLOCK_N: tl.constexpr,  # block size for indices
    BLOCK_DV: tl.constexpr,  # block size for dim_v
    BLOCK_DMODEL: tl.constexpr,  # block size for dim_nope
    BLOCK_DPE: tl.constexpr,  # block size for positional embedding
    LOGE2: tl.constexpr,
):
    cur_q = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head_id = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

    VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < h_q)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)

    off_q = cur_q * stride_q_token + cur_head[:, None] * stride_q_head + offs_d[None, :]
    mask_dmodel = offs_d < BLOCK_DMODEL
    q = tl.load(
        q_buffer + off_q, mask=(mask_h[:, None]) & (mask_dmodel[None, :]), other=0.0
    )

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        off_qpe = (
            cur_q * stride_q_token
            + cur_head[:, None] * stride_q_head
            + offs_dpe[None, :]
        )
        # assume dim_qk == BLOCK_DMODEL + BLOCK_DPE
        mask_dpe = offs_dpe < dim_qk
        qpe = tl.load(
            q_buffer + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
        )

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    for start_indice in range(0, index_topk, BLOCK_N):
        offs_indice = start_indice + tl.arange(0, BLOCK_N)
        mask_indice = offs_indice < index_topk
        indices = tl.load(
            indices_ptr
            + (
                cur_q * stride_indices_token
                + cur_kv_head_id * stride_indices_head
                + offs_indice
            ),
            mask=mask_indice,
            other=-1,
        )

        mask_kv = (indices >= 0) & (indices < seq_kv)
        mask_kv_d = mask_dmodel
        offs_k = (
            indices[None, :] * stride_k_token
            + cur_kv_head_id * stride_k_head
            + offs_d[:, None]
        )

        # q_nope @ k_nope
        k = tl.load(
            k_buffer + offs_k, mask=(mask_kv[None, :]) & (mask_kv_d[:, None]), other=0.0
        )
        qk = tl.dot(q, k.to(q.dtype))

        if BLOCK_DPE > 0:
            # q_rope @ k_rope
            offs_kpe = (
                indices[None, :] * stride_k_token
                + cur_kv_head_id * stride_k_head
                + offs_dpe[:, None]
            )
            mask_k_dpe = offs_dpe < dim_qk
            kpe = tl.load(
                k_buffer + offs_kpe,
                mask=(mask_kv[None, :]) & (mask_k_dpe[:, None]),
                other=0.0,
            )
            qk += tl.dot(qpe, kpe.to(q.dtype))

        # apply scaling
        qk *= sm_scale
        qk = tl.where((mask_h[:, None]) & (mask_kv[None, :]), qk, -float("inf"))

        # load v
        mask_v_d = offs_dv < dim_v
        offs_v = (
            indices[:, None] * stride_v_token
            + cur_kv_head_id * stride_v_head
            + offs_dv[None, :]
        )
        v = tl.load(
            v_buffer + offs_v, mask=(mask_kv[:, None]) & (mask_v_d[None, :]), other=0.0
        )

        # online softmax
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp2(e_max - n_e_max)
        p = tl.exp2(qk - n_e_max[:, None])
        acc *= re_scale[:, None]

        # score @ v
        acc += tl.dot(p.to(v.dtype), v)

        # update global sum and max
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    # rescaling by the softmax denominator
    acc /= e_sum[:, None]

    # e_max/e_sum are in the base-2 log domain (sm_scale has log2(e) folded in),
    # so multiply by ln(2) to recover natural-log values
    max_logits = e_max * LOGE2
    # calculate lse
    lse = max_logits + tl.log2(e_sum) * LOGE2

    # write output
    offs_o = (
        cur_q * stride_out_token
        + cur_head[:, None] * stride_out_head
        + offs_dv[None, :]
    )
    mask_out_d = offs_dv < dim_v
    tl.store(
        out_ptr + offs_o,
        acc.to(tl.bfloat16),
        mask=(mask_h[:, None]) & (mask_out_d[None, :]),
    )

    offs_lse = cur_q * stride_lse + cur_head
    tl.store(softmax_lse_ptr + offs_lse, lse, mask=mask_h)
    tl.store(max_logits_ptr + offs_lse, max_logits, mask=mask_h)
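
For illustration (not part of the diff): the loop above is the standard online-softmax (flash-attention style) recurrence, kept in base 2 so the kernel can use exp2/log2. A dense single-head PyTorch sketch of the same update:

    import torch

    def online_softmax_attn_ref(qk_blocks, v_blocks):
        # qk_blocks: [N] logit chunks, already scaled by sm_scale * log2(e)
        # v_blocks:  [N, Dv] value chunks; returns (out, lse in natural log)
        e_max = torch.tensor(float("-inf"))
        e_sum = torch.tensor(0.0)
        acc = torch.zeros(v_blocks[0].shape[-1])
        ln2 = torch.log(torch.tensor(2.0))
        for qk, v in zip(qk_blocks, v_blocks):
            n_e_max = torch.maximum(qk.max(), e_max)
            re_scale = torch.exp2(e_max - n_e_max)  # rescale old accumulators
            p = torch.exp2(qk - n_e_max)
            acc = acc * re_scale + p @ v
            e_sum = e_sum * re_scale + p.sum()
            e_max = n_e_max
        return acc / e_sum, (e_max + torch.log2(e_sum)) * ln2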


# reference implementation of the bf16 sparse prefill kernel
def triton_bf16_mla_sparse_interface(
    q: torch.Tensor,  # [num_tokens, num_heads_q, dim_qk]
    kv: torch.Tensor,  # [num_kv_tokens, num_heads_kv, dim_qk]
    indices: torch.Tensor,  # [num_tokens, num_heads_kv, topk]
    sm_scale: float,
    d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    out        : [num_tokens, num_heads_q, d_v]
    max_logits : [num_tokens, num_heads_q]
    lse        : logsumexp, [num_tokens, num_heads_q]
    """
    num_tokens, num_heads_q, dim_qk = q.shape
    _, num_heads_kv, _ = kv.shape
    assert dim_qk == kv.shape[2], "q and kv have different head dimensions"

    # for DeepSeek v3.2, index_topk should be 2048
    _, _, index_topk = indices.shape

    BLOCK_H = 16
    BLOCK_DMODEL = 512
    BLOCK_DPE = 64
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_DV = 512
    assert d_v == BLOCK_DV, "only d_v = 512 is supported"

    assert dim_qk == BLOCK_DMODEL + BLOCK_DPE, (
        "dim_qk does not match BLOCK_DMODEL + BLOCK_DPE"
    )
    assert num_heads_kv == 1, "only num_heads_kv = 1 is supported for now"
    assert index_topk % BLOCK_N == 0, "index_topk must be a multiple of BLOCK_N"

    # fold log2(e) into the scale so the kernel can work in exp2/log2
    sm_scale *= LOG2E

    kv_group_num = num_heads_q // num_heads_kv
    grid = (
        num_tokens,
        triton.cdiv(num_heads_q, min(BLOCK_H, kv_group_num)),
    )

    out = torch.zeros((num_tokens, num_heads_q, d_v), dtype=q.dtype, device=q.device)
    softmax_lse = torch.zeros(
        (num_tokens, num_heads_q), dtype=torch.float32, device=q.device
    )
    max_logits = torch.zeros(
        (num_tokens, num_heads_q), dtype=torch.float32, device=q.device
    )

    # v is the first d_v (nope) dims of the latent kv; the trailing rope dims
    # only participate in the qk dot product
    k = kv
    v = kv[..., :d_v]

    _bf16_mla_sparse_kernel[grid](
        q_buffer=q,
        k_buffer=k,
        v_buffer=v,
        indices_ptr=indices,
        out_ptr=out,
        softmax_lse_ptr=softmax_lse,
        max_logits_ptr=max_logits,
        seq_q=num_tokens,
        seq_kv=kv.shape[0],
        h_q=num_heads_q,
        dim_qk=dim_qk,
        dim_v=d_v,
        stride_q_token=q.stride(0),
        stride_q_head=q.stride(1),
        stride_k_token=k.stride(0),
        stride_k_head=k.stride(1),
        stride_v_token=v.stride(0),
        stride_v_head=v.stride(1),
        stride_out_token=out.stride(0),
        stride_out_head=out.stride(1),
        stride_lse=softmax_lse.stride(0),
        stride_indices_token=indices.stride(0),
        stride_indices_head=indices.stride(1),
        sm_scale=sm_scale,
        kv_group_num=kv_group_num,
        index_topk=index_topk,
        BLOCK_H=BLOCK_H,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DV=BLOCK_DV,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        LOGE2=LOGE2,
    )

    return out, max_logits, softmax_lse
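
For illustration (not part of the diff), a hypothetical smoke test of the interface with made-up sizes (DeepSeek v3.2 uses dim_qk=576 and topk=2048; "xpu" is the intended device, though the kernel is plain Triton, and the sm_scale here is illustrative):

    import torch
    from vllm.v1.attention.ops.xpu_mla_sparse import triton_bf16_mla_sparse_interface

    num_tokens, num_heads_q, dim_qk, topk = 4, 16, 576, 2048
    num_kv_tokens, device = 4096, "xpu"
    q = torch.randn(
        num_tokens, num_heads_q, dim_qk, dtype=torch.bfloat16, device=device
    )
    kv = torch.randn(num_kv_tokens, 1, dim_qk, dtype=torch.bfloat16, device=device)
    indices = torch.randint(
        0, num_kv_tokens, (num_tokens, 1, topk), dtype=torch.int32, device=device
    )
    out, max_logits, lse = triton_bf16_mla_sparse_interface(
        q, kv, indices, sm_scale=dim_qk**-0.5
    )
    assert out.shape == (num_tokens, num_heads_q, 512)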