[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Author: WeiQing Chen
Date: 2025-09-02 00:56:56 +08:00
Committed by: GitHub
Parent: cf91a89dd2
Commit: a0e0efd6bd
6 changed files with 156 additions and 61 deletions
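
The recurring pattern in this diff is swapping `torch.nn.Linear` for vLLM's `ReplicatedLinear`, which keeps a full weight copy on every data-parallel rank and, unlike `nn.Linear`, returns a tuple from `forward`. The stand-in class below is a minimal sketch of that call convention (illustrative only, not vLLM's actual implementation):

```python
# Minimal sketch of the call-convention change behind the repeated
# `x, _ = layer(x)` edits below. TupleLinear is an illustrative stand-in
# for vllm.model_executor.layers.linear.ReplicatedLinear, which holds a
# full weight replica per rank and returns (output, output_bias).
import torch
import torch.nn as nn


class TupleLinear(nn.Module):

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True,
                 prefix: str = ""):
        super().__init__()
        self.prefix = prefix  # consumed by vLLM's weight loader, unused here
        self.inner = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x: torch.Tensor):
        # The bias is returned separately (None when already applied),
        # which is why call sites unpack a 2-tuple.
        return self.inner(x), None


layer = TupleLinear(8, 16)
out, _ = layer(torch.randn(2, 8))  # note the tuple unpacking
assert out.shape == (2, 16)
```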


@@ -42,7 +42,6 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-import math
 from collections.abc import Sequence
 from copy import deepcopy
 from functools import cached_property
@@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import is_flash_attn_2_available

+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.models.utils import maybe_prefix
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig

 if is_flash_attn_2_available():
@@ -383,21 +384,30 @@ class MLP2(nn.Module):
         bias: whether to use bias in linear layer.
     """

-    def __init__(self, dims: list[int], activation, bias=True):
+    def __init__(self,
+                 dims: list[int],
+                 activation,
+                 bias=True,
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         assert len(dims) == 3
-        self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
-        self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
+        self.use_data_parallel = use_data_parallel
+        self.fc0 = ReplicatedLinear(dims[0],
+                                    dims[1],
+                                    bias=bias,
+                                    prefix=maybe_prefix(prefix, "fc0"))
+        self.fc1 = ReplicatedLinear(dims[1],
+                                    dims[2],
+                                    bias=bias,
+                                    prefix=maybe_prefix(prefix, "fc1"))
         self.activation = activation
-        for m in [self.fc0, self.fc1]:
-            nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.fc0(x)
+        x, _ = self.fc0(x)
         x = self.activation(x)
-        return self.fc1(x)
+        x, _ = self.fc1(x)
+        return x


 class MoonVitEncoderLayer(nn.Module):
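
The manual `trunc_normal_` init loop (and with it the `math` import) goes away, presumably because vLLM overwrites these parameters from the checkpoint anyway, keyed by the dotted names that `maybe_prefix` builds. A one-liner mirror of its behavior, as assumed from the call sites in this diff:

```python
# Mirror of vllm.model_executor.models.utils.maybe_prefix (behavior
# assumed from its usage above): prepend the parent prefix only when one
# is set, producing checkpoint-style dotted parameter names.
def maybe_prefix(prefix: str, name: str) -> str:
    return f"{prefix}.{name}" if prefix else name


assert maybe_prefix("", "fc0") == "fc0"
assert maybe_prefix("encoder.blocks.0.mlp", "fc0") == "encoder.blocks.0.mlp.fc0"
```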
@@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module):
         num_heads: int,
         hidden_dim: int,
         mlp_dim: int,
+        prefix: str = "",
+        use_data_parallel: bool = False,
         *,
         attn_implementation: str = "sdpa",
         activation=F.gelu,
@@ -423,9 +435,19 @@
         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
-        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
-        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
-        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
+        self.use_data_parallel = use_data_parallel
+        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim],
+                        activation,
+                        prefix=f"{prefix}.mlp",
+                        use_data_parallel=use_data_parallel)
+        self.wqkv = ReplicatedLinear(hidden_dim,
+                                     hidden_dim * 3,
+                                     bias=attn_bias,
+                                     prefix=f"{prefix}.wqkv")
+        self.wo = ReplicatedLinear(hidden_dim,
+                                   hidden_dim,
+                                   bias=attn_bias,
+                                   prefix=f"{prefix}.wo")

     def attention_qkvpacked(
         self,
@@ -438,7 +460,7 @@
             x (torch.Tensor): (batch_size, seqlen, hidden_dim)
             cu_seqlens (torch.Tensor):
         """
-        xqkv = self.wqkv(x)
+        xqkv, _ = self.wqkv(x)
         qkv_shape = xqkv.size()[:-1] + (
             3,
@@ -457,8 +479,7 @@
             xv,
             q_cu_seqlens=cu_seqlens,
             k_cu_seqlens=cu_seqlens)
-        attn_out = self.wo(attn_out)
-
+        attn_out, _ = self.wo(attn_out)
         return attn_out

     def forward(
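
For reference, the shape bookkeeping around the packed `wqkv` projection works out as follows (a standalone sketch with made-up sizes; it assumes `hidden_dim == num_heads * head_dim`):

```python
# Standalone sketch of the packed-QKV reshape done by attention_qkvpacked,
# with made-up sizes; assumes hidden_dim == num_heads * head_dim.
import torch

batch, seqlen, num_heads, head_dim = 2, 16, 8, 64
hidden_dim = num_heads * head_dim

xqkv = torch.randn(batch, seqlen, 3 * hidden_dim)  # output of wqkv
qkv_shape = xqkv.size()[:-1] + (3, num_heads, head_dim)
xq, xk, xv = xqkv.view(qkv_shape).unbind(dim=-3)
assert xq.shape == (batch, seqlen, num_heads, head_dim)
```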
@@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module):
         hidden_dim: int,
         num_layers: int,
         block_cfg: dict,
+        prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

         self.rope_2d = Rope2DPosEmb(
             block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
         self.blocks = nn.ModuleList(
-            [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
+            [MoonVitEncoderLayer(use_data_parallel=use_data_parallel,
+                                 prefix=f"{prefix}.blocks.{layer_idx}",
+                                 **block_cfg) for layer_idx in range(num_layers)])
         self.final_layernorm = nn.LayerNorm(hidden_dim)

     def forward(self, hidden_states: torch.Tensor,
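
The comprehension now threads a unique dotted prefix into each layer, so every block's weights get distinct checkpoint names. An illustration (the `"encoder"` parent prefix is an assumed example, not taken from the diff):

```python
# Per-layer prefixes produced by the comprehension above, for an assumed
# parent prefix "encoder" and 3 layers.
prefix, num_layers = "encoder", 3
names = [f"{prefix}.blocks.{layer_idx}" for layer_idx in range(num_layers)]
assert names == ["encoder.blocks.0", "encoder.blocks.1", "encoder.blocks.2"]
```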
@@ -508,10 +533,9 @@
         rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
             grid_hws=grid_hw)

-        lengths = torch.cat((
-            torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
-            grid_hw[:, 0] * grid_hw[:, 1],
-        ))
+        lengths = torch.cat(
+            (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
+             (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device)))
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

         for _, block in enumerate(self.blocks):
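
The reworked `lengths` expression additionally moves the per-image token counts onto the activations' device before concatenation; the cumulative sums it feeds into are the sequence boundaries for varlen attention. A worked example of the arithmetic:

```python
# Worked example of the cu_seqlens math above: two images with (h, w)
# grids of (4, 6) and (8, 8) contribute 24 and 64 patch tokens, giving
# cumulative boundaries [0, 24, 88] for varlen attention (CPU-only here).
import torch

grid_hw = torch.tensor([[4, 6], [8, 8]])
lengths = torch.cat(
    (torch.zeros(1, dtype=grid_hw.dtype), grid_hw[:, 0] * grid_hw[:, 1]))
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
assert cu_seqlens.tolist() == [0, 24, 88]
```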
@@ -587,11 +611,19 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True

-    def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+    def __init__(self,
+                 config: MoonViTConfig,
+                 use_data_parallel: bool = False,
+                 prefix: str = "",
+                 *inputs,
+                 **kwargs):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
+        self.use_data_parallel = use_data_parallel
         self.merge_kernel_size = config.merge_kernel_size
+        self.hidden_size = config.hidden_size
         self.patch_size = config.patch_size
+        self.vit_processing_type = "rope_2d"
         self.patch_embed = MoonVisionPatchEmbed(
             out_dim=config.hidden_size,
             patch_size=config.patch_size,
@@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },
+            prefix=f"{prefix}.encoder",
         )

     def forward(self, pixel_values: torch.Tensor,
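
Since every linear layer here is replicated rather than tensor-parallel-sharded, data parallelism for the ViT amounts to each DP rank running the full tower on its own slice of the image batch. A conceptual sketch of that idea (not code from this PR; the helper name is hypothetical):

```python
# Conceptual sketch of ViT data parallelism (not code from this PR): each
# DP rank holds a full replica of the vision tower, so parallelism comes
# from splitting the image batch across ranks and gathering the outputs.
import torch


def shard_for_rank(pixel_values: torch.Tensor, dp_rank: int,
                   dp_size: int) -> torch.Tensor:
    # Hypothetical helper: contiguous slice of the batch for this rank;
    # ranks past the end of a small batch get an empty shard.
    shards = torch.chunk(pixel_values, dp_size, dim=0)
    return shards[dp_rank] if dp_rank < len(shards) else pixel_values[:0]


images = torch.randn(5, 3, 224, 224)
sizes = [shard_for_rank(images, r, 2).shape[0] for r in range(2)]
assert sizes == [3, 2]  # torch.chunk splits 5 images into shards of 3 and 2
```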