[Docs] Fix warnings in mkdocs build (continued) (#24092)
Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers.
 """
 from collections.abc import Iterable
 from itertools import cycle
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
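Note: the `Any` added to the `typing` import in this hunk is consumed by the `**kwargs: Any` annotations introduced in the hunks further down.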
@@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
             hidden_states: Input tensor [batch_size, seq_len, hidden_size]
-            mamba_cache_params: Parameters for Mamba's state caches
-                (one for conv, one for ssm)
             sequence_idx: Index tensor for identifying sequences in batch
                 Required for proper chunked processing in prefill
             transformer_hidden_states: Optional output from transformer path
                 Added to input if provided (used in hybrid architecture)
             positions: Optional position IDs (unused in Mamba)
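For context (not part of the commit): mkdocstrings' griffe parser warns during `mkdocs build` when a docstring documents a parameter that is not in the function signature, which is why the stale `mamba_cache_params` entries above are dropped. A minimal sketch of the rule, using a hypothetical layer rather than the real Zamba2 code:

    from typing import Optional

    import torch
    from torch import nn


    class ToyMambaLayer(nn.Module):
        """Toy layer; illustrates docstring/signature agreement only."""

        def forward(
            self,
            hidden_states: torch.Tensor,
            sequence_idx: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Run the layer.

            Args:
                hidden_states: Input tensor [batch_size, seq_len, hidden_size]
                sequence_idx: Index tensor for identifying sequences in batch

            Documenting a parameter the method no longer takes (such as a
            removed mamba_cache_params) would make griffe warn here.
            """
            return hidden_states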
@@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module):
 
         Args:
             shared_transformer: Transformer decoder layer for attention pathway
-            linear: Linear projection for transformer output before Mamba
-            mamba: Mamba decoder layer for state space pathway
         """
         super().__init__()
         self.block_idx = block_idx
@@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module):
             positions: Position IDs for positional embeddings
-            mamba_cache_params: Parameters for Mamba's state caches
-                (one for conv, one for ssm)
             sequence_idx: Indices for identifying sequences in batch,
                 required for proper chunked processing in prefill
 
         Returns:
             Output tensor combining transformer and Mamba representations
@@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             prefix: Optional prefix for parameter names
 
         Raises:
-            AssertionError: If prefix caching is enabled (not supported by
-                Mamba)
+            AssertionError: If prefix caching is enabled
+                (not supported by Mamba)
         """
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
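Aside (illustrative, not from the commit): the rewrap above keeps the parenthetical qualifier whole on the Google-style continuation line instead of splitting it mid-phrase. In isolation, the rewrapped Raises entry looks like this in a hypothetical function:

    def sketch() -> None:
        """Hypothetical function showing the rewrapped Raises entry.

        Raises:
            AssertionError: If prefix caching is enabled
                (not supported by Mamba)
        """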
@@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 inputs_embeds: Optional[torch.Tensor] = None,
-                **kwargs) -> torch.Tensor:
+                **kwargs: Any) -> torch.Tensor:
         """Forward pass through the model.
 
         Args:
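Aside (illustrative, not from the commit): annotating `**kwargs: Any` types each captured keyword argument as `Any`, so the parameter is no longer implicitly untyped. A self-contained sketch with a hypothetical function name:

    from typing import Any

    import torch


    def forward_sketch(input_ids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # With **kwargs: Any, kwargs is a dict[str, Any] inside the body,
        # which satisfies checkers that flag unannotated parameters.
        _ = kwargs
        return input_ids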
@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 
         return hidden_states
 
-    def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
-                                                                 torch.Tensor],
-                                       **kwargs) -> dict[str, torch.Tensor]:
+    def copy_inputs_before_cuda_graphs(
+            self, input_buffers: dict[str, torch.Tensor],
+            **kwargs: Any) -> dict[str, torch.Tensor]:
         """Copy inputs before CUDA graph capture.
 
         Args:
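Background on the method being reformatted (a minimal sketch under assumptions, not vLLM's actual implementation): CUDA graphs replay kernels against fixed device addresses, so fresh inputs must be copied in place into the captured static buffers rather than rebound to new tensors. The hypothetical helper below shows that pattern:

    import torch


    def copy_inputs_sketch(
            input_buffers: dict[str, torch.Tensor],
            **kwargs: torch.Tensor) -> dict[str, torch.Tensor]:
        # In-place copy_ keeps each buffer's device address stable,
        # which is what a captured CUDA graph replays against.
        for name, value in kwargs.items():
            if name in input_buffers:
                input_buffers[name].copy_(value)
        return input_buffers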