[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (#7651)

This commit is contained in:
Mor Zusman
2024-08-29 01:06:52 +03:00
committed by GitHub
parent 8c56e57def
commit fdd9daafa3
20 changed files with 2815 additions and 31 deletions

View File

@@ -4,9 +4,6 @@ from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
import torch
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig
@@ -24,6 +21,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
+causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
+selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
@@ -161,7 +162,7 @@ class JambaMambaMixer(nn.Module):
(self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_state.copy_(conv_states)
-hidden_states = causal_conv1d_fn(
+hidden_states, _ = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,