[Model] Support DP for ViT on Kimi-VL-A3B-Thinking-2506 (#23817)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -42,7 +42,6 @@
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
@@ -55,6 +54,8 @@ from transformers.activations import ACT2FN, PytorchGELUTanh
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
@@ -383,21 +384,30 @@ class MLP2(nn.Module):
|
||||
bias: whether to use bias in linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: list[int], activation, bias=True):
|
||||
def __init__(self,
|
||||
dims: list[int],
|
||||
activation,
|
||||
bias=True,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False):
|
||||
super().__init__()
|
||||
assert len(dims) == 3
|
||||
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
|
||||
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.fc0 = ReplicatedLinear(dims[0],
|
||||
dims[1],
|
||||
bias=bias,
|
||||
prefix=maybe_prefix(prefix, "fc0"))
|
||||
self.fc1 = ReplicatedLinear(dims[1],
|
||||
dims[2],
|
||||
bias=bias,
|
||||
prefix=maybe_prefix(prefix, "fc1"))
|
||||
self.activation = activation
|
||||
for m in [self.fc0, self.fc1]:
|
||||
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc0(x)
|
||||
x, _ = self.fc0(x)
|
||||
x = self.activation(x)
|
||||
return self.fc1(x)
|
||||
x, _ = self.fc1(x)
|
||||
return x
|
||||
|
||||
|
||||
class MoonVitEncoderLayer(nn.Module):
|
||||
@@ -407,6 +417,8 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp_dim: int,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
*,
|
||||
attn_implementation: str = "sdpa",
|
||||
activation=F.gelu,
|
||||
@@ -423,9 +435,19 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
|
||||
self.norm0 = nn.LayerNorm(hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim)
|
||||
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim],
|
||||
activation,
|
||||
prefix=f"{prefix}.mlp",
|
||||
use_data_parallel=use_data_parallel)
|
||||
self.wqkv = ReplicatedLinear(hidden_dim,
|
||||
hidden_dim * 3,
|
||||
bias=attn_bias,
|
||||
prefix=f"{prefix}.wqkv")
|
||||
self.wo = ReplicatedLinear(hidden_dim,
|
||||
hidden_dim,
|
||||
bias=attn_bias,
|
||||
prefix=f"{prefix}.wo")
|
||||
|
||||
def attention_qkvpacked(
|
||||
self,
|
||||
@@ -438,7 +460,7 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
|
||||
cu_seqlens (torch.Tensor):
|
||||
"""
|
||||
xqkv = self.wqkv(x)
|
||||
xqkv, _ = self.wqkv(x)
|
||||
|
||||
qkv_shape = xqkv.size()[:-1] + (
|
||||
3,
|
||||
@@ -457,8 +479,7 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
xv,
|
||||
q_cu_seqlens=cu_seqlens,
|
||||
k_cu_seqlens=cu_seqlens)
|
||||
|
||||
attn_out = self.wo(attn_out)
|
||||
attn_out, _ = self.wo(attn_out)
|
||||
return attn_out
|
||||
|
||||
def forward(
|
||||
@@ -494,13 +515,17 @@ class MoonVitEncoder(nn.Module):
|
||||
hidden_dim: int,
|
||||
num_layers: int,
|
||||
block_cfg: dict,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.rope_2d = Rope2DPosEmb(
|
||||
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)])
|
||||
[MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \
|
||||
prefix=f"{prefix}.blocks.{layer_idx}", \
|
||||
**block_cfg) for layer_idx in range(num_layers)])
|
||||
self.final_layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
@@ -508,10 +533,9 @@ class MoonVitEncoder(nn.Module):
|
||||
rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(
|
||||
grid_hws=grid_hw)
|
||||
|
||||
lengths = torch.cat((
|
||||
torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
|
||||
grid_hw[:, 0] * grid_hw[:, 1],
|
||||
))
|
||||
lengths = torch.cat(
|
||||
(torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
|
||||
(grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device)))
|
||||
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
|
||||
|
||||
for _, block in enumerate(self.blocks):
|
||||
@@ -587,11 +611,19 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
|
||||
def __init__(self,
|
||||
config: MoonViTConfig,
|
||||
use_data_parallel: bool = False,
|
||||
prefix: str = "",
|
||||
*inputs,
|
||||
**kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
config = deepcopy(config)
|
||||
self.use_data_parallel = use_data_parallel
|
||||
self.merge_kernel_size = config.merge_kernel_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.patch_size = config.patch_size
|
||||
self.vit_processing_type = "rope_2d"
|
||||
self.patch_embed = MoonVisionPatchEmbed(
|
||||
out_dim=config.hidden_size,
|
||||
patch_size=config.patch_size,
|
||||
@@ -610,6 +642,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
||||
"attn_bias": True,
|
||||
"attn_implementation": config._attn_implementation,
|
||||
},
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user