[Refactor] Remove redundant TP gather/split in split_qkv in QwenVL (#28271)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
This commit is contained in:
@@ -291,25 +291,6 @@ class Qwen2_5_VisionMLP(nn.Module):
|
|||||||
return x_down
|
return x_down
|
||||||
|
|
||||||
|
|
||||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
|
||||||
"""All-gather the input tensor interleavely across model parallel group."""
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
|
|
||||||
dist.all_gather(
|
|
||||||
gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
|
|
||||||
)
|
|
||||||
|
|
||||||
gathered_tensors_split = [
|
|
||||||
torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
|
|
||||||
]
|
|
||||||
ordered_tensors = [
|
|
||||||
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
|
|
||||||
]
|
|
||||||
result_tensor = torch.cat(ordered_tensors, dim=-1)
|
|
||||||
return result_tensor
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionAttention(nn.Module):
|
class Qwen2_5_VisionAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -383,21 +364,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
if self.tp_size > 1:
|
|
||||||
qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size)
|
|
||||||
|
|
||||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||||
q, k, v = qkv.chunk(3, dim=2)
|
q, k, v = qkv.chunk(3, dim=2)
|
||||||
|
|
||||||
# 3 * [s, b, head * head_dim]
|
|
||||||
if self.tp_size > 1:
|
|
||||||
splitter = partial(
|
|
||||||
dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
|
|
||||||
)
|
|
||||||
q = splitter(q)[self.tp_rank]
|
|
||||||
k = splitter(k)[self.tp_rank]
|
|
||||||
v = splitter(v)[self.tp_rank]
|
|
||||||
|
|
||||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||||
new_shape = (
|
new_shape = (
|
||||||
seq_len,
|
seq_len,
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from vllm.attention.layer import (
|
|||||||
)
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import QuickGELU
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
@@ -396,21 +396,10 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
if self.tp_size > 1:
|
|
||||||
qkv = tensor_model_parallel_all_gather(qkv)
|
|
||||||
|
|
||||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||||
q, k, v = qkv.chunk(3, dim=2)
|
q, k, v = qkv.chunk(3, dim=2)
|
||||||
|
|
||||||
# 3 * [s, b, head * head_dim]
|
|
||||||
if self.tp_size > 1:
|
|
||||||
splitter = partial(
|
|
||||||
dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
|
|
||||||
)
|
|
||||||
q = splitter(q)[self.tp_rank]
|
|
||||||
k = splitter(k)[self.tp_rank]
|
|
||||||
v = splitter(v)[self.tp_rank]
|
|
||||||
|
|
||||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||||
new_shape = (
|
new_shape = (
|
||||||
seq_len,
|
seq_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user