[Model][Spec Decode] Nemotron-H MTP and Mamba Speculative Decoding Support (#33726)
Signed-off-by: Shahar Mor <smor@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Shahar Mor <smor@nvidia.com>
Co-authored-by: Roi Koren <roik@nvidia.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Parent: a9e15e040d
Commit: f5972a872f
@@ -228,6 +228,7 @@ class Mamba2ForCausalLM(
             head_dim=hf_config.head_dim,
             state_size=hf_config.state_size,
             conv_kernel=hf_config.conv_kernel,
+            num_spec=vllm_config.num_speculative_tokens,
         )

     @classmethod
@@ -636,6 +636,9 @@ class NemotronHModel(nn.Module):
         hidden_states, _ = self.norm_f(hidden_states, residual)
         return hidden_states

+    def is_spec_layer(self, config: NemotronHConfig, weight_name: str) -> bool:
+        return weight_name.startswith("mtp.")
+
     def _get_max_n_routed_experts(self) -> int:
         """Get max n_routed_experts from config or block_configs for puzzle models.
@@ -702,6 +705,10 @@ class NemotronHModel(nn.Module):
             if name is None:
                 continue

+            # Skip MTP/spec decode layers early (before stacked params mapping)
+            if name.startswith("mtp."):
+                continue
+
             # load stacked params
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
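The two NemotronHModel hunks above let the target model recognize and skip draft-model weights: is_spec_layer flags any checkpoint name under the "mtp." prefix, and load_weights drops those names before the stacked-parameter mapping so they are left for the draft model to load. A tiny self-contained illustration with made-up checkpoint names:

# --- illustrative sketch (not part of this commit); weight names are hypothetical ---
checkpoint_names = [
    "backbone.layers.0.mixer.A_log",      # target Mamba weight -> loaded by NemotronHModel
    "mtp.layers.0.mixer.q_proj.weight",   # draft/MTP weight    -> skipped here
    "mtp.layers.0.eh_proj.weight",        # draft/MTP fusion    -> skipped here
]

target_names = [n for n in checkpoint_names if not n.startswith("mtp.")]
print(target_names)  # ['backbone.layers.0.mixer.A_log']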
@@ -845,6 +852,7 @@ class NemotronHForCausalLM(
             head_dim=hf_config.mamba_head_dim,
             state_size=hf_config.ssm_state_size,
             conv_kernel=hf_config.conv_kernel,
+            num_spec=vllm_config.num_speculative_tokens,
         )

     @classmethod
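Both hunks above thread the number of speculative tokens into the Mamba state/cache setup. A rough intuition for why the convolution cache needs to know num_spec: with N draft tokens verified per step, the layer must hold enough cached columns to roll the causal convolution forward over those extra tokens. The sizing helper below is a self-contained sketch of that idea only; the function name and formula are illustrative assumptions, not vLLM's actual state-shape calculator.

# --- illustrative sketch (not part of this commit) ---
def conv_state_shape(num_channels: int, conv_kernel: int, num_spec: int) -> tuple[int, int]:
    """Illustrative sizing: keep (conv_kernel - 1) past columns, plus room for
    the speculative tokens processed in a single step."""
    return (num_channels, conv_kernel - 1 + num_spec)

# e.g. conv_kernel=4: 3 cached columns without speculation, 6 with 3 draft tokens
print(conv_state_shape(256, 4, 0), conv_state_shape(256, 4, 3))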
vllm/model_executor/models/nemotron_h_mtp.py (new file, 503 lines)
@@ -0,0 +1,503 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NemotronH-MTP model with attention layers."""

import typing
from collections.abc import Callable, Iterable

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig

from .interfaces import SupportsPP
from .nemotron_h import (
    NemotronHAttentionDecoderLayer,
    NemotronHMoEDecoderLayer,
)

class NemotronHMTPAttentionDecoderLayer(NemotronHAttentionDecoderLayer):
    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
        has_start_projections: bool = False,
        has_end_norm: bool = False,
    ) -> None:
        super().__init__(
            config=config,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            parallel_config=parallel_config,
            prefix=prefix,
        )
        self.has_start_projections = has_start_projections
        self.has_end_norm = has_end_norm

        if has_start_projections:
            self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

            # Fusion layer to combine embeddings with target hidden states
            self.eh_proj = ColumnParallelLinear(
                input_size=config.hidden_size * 2,
                output_size=config.hidden_size,
                bias=False,
                gather_output=True,
                params_dtype=config.dtype
                if hasattr(config, "dtype")
                else torch.bfloat16,
                quant_config=quant_config,
                prefix=f"{prefix}.eh_proj",
            )

        if has_end_norm:
            self.final_layernorm = RMSNorm(
                config.hidden_size,
                eps=getattr(config, "layer_norm_epsilon", 1e-5),
            )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Start projections (Fusion)
        if self.has_start_projections:
            # Normalize both inputs before fusion
            assert inputs_embeds is not None
            inputs_embeds_normed = self.enorm(inputs_embeds)
            previous_hidden_states_normed = self.hnorm(hidden_states)

            # Fuse via concatenation and linear projection
            fused = torch.cat(
                [inputs_embeds_normed, previous_hidden_states_normed], dim=-1
            )
            hidden_states, _ = self.eh_proj(fused)

        # Call parent forward (Attention)
        # Parent forward expects: hidden_states, residual
        hidden_states, residual = super().forward(
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
        )

        # End norm
        if self.has_end_norm:
            if residual is not None:
                hidden_states = hidden_states + residual
                residual = None  # Consumed residual

            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, residual

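The start-projection path above follows the familiar MTP fusion pattern: the draft layer receives both the embedding of the incoming token and the target model's hidden state, RMS-normalizes each, concatenates them, and projects the pair back down to hidden_size. A minimal single-GPU sketch of the same computation, using plain nn.Linear in place of ColumnParallelLinear (assumes PyTorch >= 2.4 for nn.RMSNorm; not part of this commit):

# --- illustrative sketch (not part of this commit) ---
import torch
import torch.nn as nn


class EmbeddingHiddenFusion(nn.Module):
    """Minimal stand-in for enorm/hnorm + eh_proj (plain Linear, no tensor parallelism)."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size, eps=eps)
        self.hnorm = nn.RMSNorm(hidden_size, eps=eps)
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize both streams, concatenate along the feature dim, project back.
        fused = torch.cat([self.enorm(inputs_embeds), self.hnorm(hidden_states)], dim=-1)
        return self.eh_proj(fused)


emb = torch.randn(2, 5, 64)   # token embeddings
hid = torch.randn(2, 5, 64)   # target-model hidden states
print(EmbeddingHiddenFusion(64)(emb, hid).shape)  # torch.Size([2, 5, 64])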
class NemotronHMTPMoEDecoderLayer(NemotronHMoEDecoderLayer):
    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        prefix: str = "",
        has_start_projections: bool = False,
        has_end_norm: bool = False,
    ) -> None:
        super().__init__(
            config=config,
            layer_idx=layer_idx,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            parallel_config=parallel_config,
            prefix=prefix,
        )
        self.has_start_projections = has_start_projections
        self.has_end_norm = has_end_norm

        if has_start_projections:
            self.enorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
            self.hnorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

            # Fusion layer to combine embeddings with target hidden states
            self.eh_proj = ColumnParallelLinear(
                input_size=config.hidden_size * 2,
                output_size=config.hidden_size,
                bias=False,
                gather_output=True,
                params_dtype=config.dtype
                if hasattr(config, "dtype")
                else torch.bfloat16,
                quant_config=quant_config,
                prefix=f"{prefix}.eh_proj",
            )

        if has_end_norm:
            self.final_layernorm = RMSNorm(
                config.hidden_size,
                eps=getattr(config, "layer_norm_epsilon", 1e-5),
            )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Start projections (Fusion)
        if self.has_start_projections:
            # Normalize both inputs before fusion
            assert inputs_embeds is not None
            inputs_embeds_normed = self.enorm(inputs_embeds)
            previous_hidden_states_normed = self.hnorm(hidden_states)

            # Fuse via concatenation and linear projection
            fused = torch.cat(
                [inputs_embeds_normed, previous_hidden_states_normed], dim=-1
            )
            hidden_states, _ = self.eh_proj(fused)

        # Call parent forward (MoE)
        hidden_states, residual = super().forward(
            hidden_states=hidden_states,
            residual=residual,
        )

        # End norm
        if self.has_end_norm:
            if residual is not None:
                hidden_states = hidden_states + residual
                residual = None  # Consumed residual

            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, residual

@support_torch_compile
class NemotronHMultiTokenPredictor(nn.Module):
    """MTP predictor with NemotronH layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
        assert self.num_mtp_layers == 1, (
            "Only one MTP layer is supported for NemotronH-MTP"
        )

        self.pattern_str = config.mtp_hybrid_override_pattern
        self.pattern_len = len(self.pattern_str)
        assert self.pattern_len > 0

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        # Build flat list of layers
        self.layers = torch.nn.ModuleDict()

        # Total number of physical layers = num_steps * pattern_len
        total_layers = self.num_mtp_layers * self.pattern_len
        for i in range(total_layers):
            step_rel_idx = i % self.pattern_len

            char = self.pattern_str[step_rel_idx]

            is_start_of_step = step_rel_idx == 0
            is_end_of_step = step_rel_idx == self.pattern_len - 1

            layer_prefix = f"{prefix}.layers.{i}"

            # TODO smor- remove double layers formation
            common_kwargs = dict(
                config=config,
                layer_idx=self.mtp_start_layer_idx + i,
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
                parallel_config=vllm_config.parallel_config,
                prefix=layer_prefix,
                has_start_projections=is_start_of_step,
                has_end_norm=is_end_of_step,
            )

            if char == "*":
                self.layers[str(i)] = NemotronHMTPAttentionDecoderLayer(**common_kwargs)
            elif char == "E":
                self.layers[str(i)] = NemotronHMTPMoEDecoderLayer(**common_kwargs)
            else:
                raise NotImplementedError(
                    f"Pattern char '{char}' in {self.pattern_str} not implemented"
                )

        self.make_empty_intermediate_tensors: Callable[..., IntermediateTensors] = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size
            )
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        assert self.embed_tokens is not None, (
            "embed_tokens not initialized - must be shared from target model"
        )
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        residual = None

        for i in range(self.pattern_len):
            hidden_states, residual = self.layers[str(i)](
                inputs_embeds=inputs_embeds,
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        return hidden_states

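A small sketch of how mtp_hybrid_override_pattern drives layer construction in the predictor above: each character selects one physical layer, '*' an attention layer and 'E' an MoE layer; the first layer of the step carries the fusion projections and the last carries the final norm. The pattern string below is a made-up example, not taken from any real Nemotron-H config:

# --- illustrative sketch (not part of this commit) ---
pattern = "*E*"  # hypothetical mtp_hybrid_override_pattern

for i, char in enumerate(pattern):
    layer_cls = {"*": "NemotronHMTPAttentionDecoderLayer",
                 "E": "NemotronHMTPMoEDecoderLayer"}[char]
    print(
        f"layer {i}: {layer_cls}, "
        f"start_projections={i == 0}, end_norm={i == len(pattern) - 1}"
    )
# layer 0: NemotronHMTPAttentionDecoderLayer, start_projections=True, end_norm=False
# layer 1: NemotronHMTPMoEDecoderLayer, start_projections=False, end_norm=False
# layer 2: NemotronHMTPAttentionDecoderLayer, start_projections=False, end_norm=True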
class NemotronHMTP(nn.Module, SupportsPP):
    """NemotronH MTP model."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = vllm_config.quant_config

        # Needed for load_weights mapping
        self.mtp_start_layer_idx = config.num_hidden_layers

        # EPLB config for experts
        self.num_redundant_experts = 0
        if vllm_config.parallel_config and vllm_config.parallel_config.eplb_config:
            self.num_redundant_experts = (
                vllm_config.parallel_config.eplb_config.num_redundant_experts
            )

        # MTP predictor
        self.model = NemotronHMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )

        # LM head for generating logits
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        self.logits_processor = LogitsProcessor(self.config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Forward - applies attention-based MTP."""
        hidden_states = self.model(
            input_ids,
            positions,
            hidden_states,
            intermediate_tensors,
            inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits for DRAFT token generation."""
        assert self.lm_head is not None, (
            "lm_head not initialized - must be shared from target model"
        )
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load MTP weights with proper name remapping."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = []
        if hasattr(self.config, "n_routed_experts") and self.config.n_routed_experts:
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                self,
                ckpt_gate_proj_name="up_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="",  # Empty - non-gated MoE
                num_experts=self.config.n_routed_experts,
                num_redundant_experts=self.num_redundant_experts,
            )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Only process MTP weights - skip all non-MTP weights
            if (
                not name.startswith("mtp.")
                and "embeddings" not in name
                and "lm_head" not in name
            ):
                continue
            # Skip rotary embeddings (computed, not loaded)
            if "rotary_emb.inv_freq" in name:
                continue

            name = name.replace("mtp.layers.", "model.layers.")

            if "embeddings" in name:
                name = name.replace("embeddings", "embed_tokens")
            if name.startswith("backbone."):
                name = name.replace("backbone.", "model.")

            # Handle stacked parameters (qkv_proj) for attention layers
            is_stacked = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Must be in a mixer (attention layer)
                if ".mixer." not in name:
                    continue

                is_stacked = True
                stacked_name = name.replace(weight_name, param_name)

                if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                    continue

                if stacked_name not in params_dict:
                    # Might be that mapping failed or param doesn't exist
                    continue

                param = params_dict[stacked_name]
                weight_loader = getattr(param, "weight_loader", None)
                if weight_loader is not None:
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(stacked_name)
                break

            if is_stacked:
                continue

            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                # weight_name is like "experts.0.up_proj."
                if weight_name not in name:
                    continue

                is_expert_weight = True

                # Replace the expert-specific weight name with fused parameter name
                # e.g., "experts.0.up_proj." -> "experts.w13_"
                name_mapped = name.replace(weight_name, param_name)

                if name_mapped not in params_dict:
                    continue

                param = params_dict[name_mapped]
                weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
                success = weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if success:
                    loaded_params.add(name_mapped)
                    break

            if is_expert_weight:
                continue

            if name.endswith(".bias") and name not in params_dict:
                continue

            if name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
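For reference, the renaming performed by load_weights above on a hypothetical checkpoint entry: the "mtp." prefix is rewritten to this module's namespace, then q/k/v projections in attention mixers are folded into the stacked qkv_proj parameter. The weight name below is a made-up example, not taken from a real checkpoint:

# --- illustrative sketch (not part of this commit) ---
name = "mtp.layers.0.mixer.q_proj.weight"  # hypothetical checkpoint entry

# Step 1: move it under the local module path.
name = name.replace("mtp.layers.", "model.layers.")

# Step 2: stacked-parameter mapping (q/k/v -> qkv) with the matching shard id.
for param_name, weight_name, shard_id in [("qkv_proj", "q_proj", "q"),
                                          ("qkv_proj", "k_proj", "k"),
                                          ("qkv_proj", "v_proj", "v")]:
    if weight_name in name and ".mixer." in name:
        print(name.replace(weight_name, param_name), shard_id)
        break
# model.layers.0.mixer.qkv_proj.weight q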
@@ -266,7 +266,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
-        state_indices_tensor = attn_metadata.state_indices_tensor
+        state_indices_tensor_p = attn_metadata.state_indices_tensor_p
+        state_indices_tensor_d = attn_metadata.state_indices_tensor_d
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states
         chunk_size = attn_metadata.chunk_size
@@ -309,13 +310,6 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         gate_d, gate_p = torch.split(
             gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0
         )
-        # Split along batch dimension
-        state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor,
-            [num_decodes, num_prefills],
-            dim=0,
-        )
-
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
         preallocated_ssm_out = torch.empty(
@@ -336,7 +330,7 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
         if has_prefill:
             # 2. Convolution sequence transformation
             # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "state_indices_tensor"
+            #   pointed to by "state_indices_tensor_p"
             x = hidden_states_p.transpose(0, 1)  # this is the form that causal-conv see
             hidden_states_p = causal_conv1d_fn(
                 x,
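The Plamo2 hunks above move the decode/prefill split of the Mamba state indices out of the mixer's forward pass: instead of slicing a single state_indices_tensor at runtime, the layer now reads two tensors that the attention-metadata builder has already prepared. A toy before/after sketch (tensor names follow the diff; the values and the stand-in metadata are made up, not vLLM's real builder):

# --- illustrative sketch (not part of this commit) ---
import torch

num_decodes, num_prefills = 3, 2
state_indices_tensor = torch.tensor([7, 1, 4, 9, 2])  # decode slots first, then prefill

# Before: split along the batch dimension inside the mixer forward.
state_indices_tensor_d, state_indices_tensor_p = torch.split(
    state_indices_tensor, [num_decodes, num_prefills], dim=0
)

# After: the metadata builder precomputes both tensors once per step, so the
# mixer just reads attn_metadata.state_indices_tensor_d / _p directly.
print(state_indices_tensor_d.tolist(), state_indices_tensor_p.tolist())  # [7, 1, 4] [9, 2]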
@@ -522,6 +522,7 @@ _SPECULATIVE_DECODING_MODELS = {
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
     "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
+    "NemotronHMTPModel": ("nemotron_h_mtp", "NemotronHMTP"),
     "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
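With the registry entry in place, the Nemotron-H draft can be selected through vLLM's speculative_config. The snippet below is only a rough usage sketch: the checkpoint path and the "method" string are placeholders/assumptions, and the exact keys accepted for Nemotron-H MTP should be checked against the accompanying docs and tests.

# --- illustrative sketch (not part of this commit); model path and method name are assumptions ---
from vllm import LLM, SamplingParams

llm = LLM(
    model="nvidia/NemotronH-with-MTP",   # hypothetical checkpoint containing mtp.* weights
    speculative_config={
        "method": "nemotron_h_mtp",       # assumed method name, not verified
        "num_speculative_tokens": 3,
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32))[0].outputs[0].text)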