[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)
This commit is contained in:
@@ -4,9 +4,6 @@ from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from transformers import JambaConfig
|
||||
@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
+causal_conv1d_fn, causal_conv1d_update)
|
||||
+from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
+selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
|
||||
(self.conv_kernel_size - hidden_states.shape[-1], 0))
|
||||
cache_params.conv_state.copy_(conv_states)
|
||||
|
||||
-hidden_states = causal_conv1d_fn(
|
||||
+hidden_states, _ = causal_conv1d_fn(
|
||||
hidden_states,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
|
||||
Reference in New Issue
Block a user