[New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Xiaozhu Meng <mxz297@gmail.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
@@ -33,15 +33,21 @@ from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config

from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
@@ -49,6 +55,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -56,13 +64,26 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
                                                    DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec

from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)


class DeepseekV2MLP(nn.Module):

@@ -276,6 +297,7 @@ class DeepseekV2Attention(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
@@ -289,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        topk_indices_buffer: Optional[torch.Tensor] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
@@ -306,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        assert topk_indices_buffer is None, "topk_indices_buffer is not \
            supported for DeepseekV2Attention"

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
@@ -418,6 +443,391 @@ class DeepseekV2Attention(nn.Module):
        return output


class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):

    def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
                 cache_config: CacheConfig):
        super().__init__()
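        # placeholder; the real paged cache tensor is expected to be bound to
        # this list by the engine once KV-cache buffers are allocated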
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self) -> KVCacheSpec:
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self):
        ...

    def get_attn_backend(self) -> AttentionBackend:
        return DeepseekV32IndexerBackend


@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
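    # Gather the quantized indexer K cache for each prefill sequence into
    # contiguous buffers: walk the sequence's block table, split every cache
    # row into its fp8 value bytes and fp32 scale bytes, and copy the
    # concatenated result into dst_value / dst_scale.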
    num_blocks, block_size, _ = kv_cache.shape
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)

    expected_value = []
    expected_scale = []
    for b in range(batch_size):
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        value = []
        scale = []
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        non_remaining_value = kv_cache[blocks[full_block], :block_size *
                                       head_dim].view(-1, head_dim)
        non_remaining_scale = kv_cache[blocks[full_block],
                                       block_size * head_dim:].view(-1, 4)

        remaining = s - (tot - 1) * block_size

        value = torch.cat([
            non_remaining_value,
            kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
        ],
                          dim=0)
        scale = torch.cat([
            non_remaining_scale,
            kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
                     remaining * 4].view(-1, 4)
        ],
                          dim=0)

        expected_value.append(value)
        expected_scale.append(scale)

    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    gather_value = gather_value.view(torch.float8_e4m3fn)
    gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    dst_scale.copy_(gather_scale)


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
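    # Top-k token selection for the DeepSeek-V3.2 sparse-attention indexer:
    # write the new K entries into the paged indexer cache, score tokens with
    # fp8 MQA logits, and store the selected indices in topk_indices_buffer.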

    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
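    # Dummy/profile runs don't populate per-layer metadata; fall back to the
    # fake impl so the profiling pass still sees realistic memory usage.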
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

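    # Every slot starts as -1 ("not selected"); valid indices are filled in
    # below for the prefill and decode portions of the batch.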
    topk_indices_buffer[:hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        num_prefills = attn_metadata.num_prefills
        k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
                            device=k.device,
                            dtype=torch.float8_e4m3fn)
        k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
                              device=k.device,
                              dtype=torch.float32)
        cp_gather_indexer_k_quant_cache(
            kv_cache,
            k_fp8,
            k_scale,
            prefill_metadata.block_table,
            prefill_metadata.cu_seq_lens,
            num_prefills,
        )
        cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
        cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
        num_tokens = attn_metadata.num_actual_tokens
        logits = fp8_mqa_logits(
            q_fp8[num_decode_tokens:num_tokens],
            (k_fp8, k_scale),
            weights[num_decode_tokens:num_tokens],
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
                                   dim=-1)[1]
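        # Convert global token positions into per-sequence offsets and mask
        # out anything outside this row's [cu_seqlen_ks, cu_seqlen_ke) window.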
        topk_indices -= cu_seqlen_ks[:, None]
        mask_lo = topk_indices >= 0
        mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
        mask = torch.full_like(topk_indices,
                               False,
                               dtype=torch.bool,
                               device=topk_indices.device)
        mask = mask_lo & mask_hi
        topk_indices = topk_indices.masked_fill(~mask, -1)
        topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )
        # padded query len
        current_device = padded_q_fp8_decode_tokens.device
        padded_num_tokens = batch_size * next_n
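        # Causal mask over the padded decode queries: token i of a sequence
        # may attend up to position seq_len - next_n + i (inclusive).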
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     batch_size * next_n, -1)
        row_indices = torch.arange(padded_num_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(
            padded_num_tokens,
            device=padded_q_fp8_decode_tokens.device) % next_n
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        topk_indices = logits.topk(topk_tokens,
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        # ensure we don't set indices for the top k
        # that is out of range (masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                device=k.device,
                                dtype=torch.uint8)
    _k_fp8 = _flattened_kv[..., :head_dim].view(
        torch.float8_e4m3fn).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


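# Register as a torch custom op: it mutates topk_indices_buffer in place, and
# the fake impl stands in when only shapes/memory are traced (profile run).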
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: int,
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor],
                 prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(self.q_lora_rank,
                                     self.head_dim * self.n_head,
                                     bias=False,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.wq_b")
        self.wk = ReplicatedLinear(hidden_size,
                                   self.head_dim,
                                   bias=False,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.wk")
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        # where we store value in fp8 and scale in fp32
        # per self.quant_block_size element
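        # each cache row holds head_dim fp8 bytes plus one 4-byte fp32 scale
        # per quant_block_size elements, hence the extra
        # head_dim // quant_block_size * 4 bytes below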
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim +
            self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config)
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import (
            get_max_prefill_buffer_size)
        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(q,
                                                   self.quant_block_size,
                                                   column_major_scales=False,
                                                   use_ue8m0=self.scale_fmt
                                                   is not None)
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

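        # fold the per-group q dequant scale, the softmax scale, and a
        # 1/sqrt(n_head) normalization into the per-head gating weights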
        weights, _ = self.weights_proj(hidden_states)
        weights = weights.unsqueeze(
            -1) * q_scale * self.softmax_scale * self.n_head**-0.5
        weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
@@ -429,6 +839,7 @@ class DeepseekV2MLAAttention(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
@@ -443,6 +854,7 @@ class DeepseekV2MLAAttention(nn.Module):
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -523,6 +935,15 @@ class DeepseekV2MLAAttention(nn.Module):
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer = Indexer(vllm_config, config, hidden_size,
                                   q_lora_rank, quant_config, cache_config,
                                   topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer = None

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
@@ -536,7 +957,11 @@ class DeepseekV2MLAAttention(nn.Module):
            if self.q_lora_rank is not None else None,
            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else None,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
@@ -562,7 +987,10 @@ class DeepseekV2MLAAttention(nn.Module):

class DeepseekV2DecoderLayer(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str,
                 topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
@@ -585,6 +1013,7 @@ class DeepseekV2DecoderLayer(nn.Module):
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -600,6 +1029,7 @@ class DeepseekV2DecoderLayer(nn.Module):
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        if (config.n_routed_experts is not None
@@ -683,6 +1113,16 @@ class DeepseekV2Model(nn.Module):
        self.config = config

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
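            # single buffer shared by every layer's indexer, sized for the
            # largest possible batch of tokens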
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device="cuda")
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
@@ -695,7 +1135,8 @@ class DeepseekV2Model(nn.Module):

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
                                                  topk_indices_buffer),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank: