[Docs] Fix warnings in mkdocs build (continued) (#24092)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Hyogeun Oh (오효근)
2025-09-10 22:23:28 +09:00
committed by GitHub
parent c0bd6a684a
commit ccee371e86
10 changed files with 337 additions and 342 deletions

@@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers.
"""
from collections.abc import Iterable
from itertools import cycle
from typing import Optional, Union
from typing import Any, Optional, Union
import torch
from torch import nn
@@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
             hidden_states: Input tensor [batch_size, seq_len, hidden_size]
             mamba_cache_params: Parameters for Mamba's state caches
                 (one for conv, one for ssm)
-            sequence_idx: Index tensor for identifying sequences in batch
-                Required for proper chunked processing in prefill
             transformer_hidden_states: Optional output from transformer path
                 Added to input if provided (used in hybrid architecture)
             positions: Optional position IDs (unused in Mamba)
@@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module):
 
         Args:
             shared_transformer: Transformer decoder layer for attention pathway
-            linear: Linear projection for transformer output before Mamba
-            mamba: Mamba decoder layer for state space pathway
         """
         super().__init__()
         self.block_idx = block_idx
@@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module):
             positions: Position IDs for positional embeddings
             mamba_cache_params: Parameters for Mamba's state caches
                 (one for conv, one for ssm)
-            sequence_idx: Indices for identifying sequences in batch,
-                required for proper chunked processing in prefill
 
         Returns:
             Output tensor combining transformer and Mamba representations
@@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             prefix: Optional prefix for parameter names
 
         Raises:
-            AssertionError: If prefix caching is enabled (not supported by
-                Mamba)
+            AssertionError: If prefix caching is enabled
+                (not supported by Mamba)
         """
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -971,7 +965,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 inputs_embeds: Optional[torch.Tensor] = None,
-                **kwargs) -> torch.Tensor:
+                **kwargs: Any) -> torch.Tensor:
         """Forward pass through the model.
 
         Args:
@@ -1012,9 +1006,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
 
         return hidden_states
 
-    def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
-                                                                 torch.Tensor],
-                                       **kwargs) -> dict[str, torch.Tensor]:
+    def copy_inputs_before_cuda_graphs(
+            self, input_buffers: dict[str, torch.Tensor],
+            **kwargs: Any) -> dict[str, torch.Tensor]:
         """Copy inputs before CUDA graph capture.
 
         Args: