[Perf]: Optimize qwen2-vl to reduce cudaMemcpyAsync (#14377)
Signed-off-by: cynthieye <987073381@qq.com>
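This change threads host-side sequence-length values (max_seqlen for Flash Attention, seqlens for xFormers) from the vision transformer's forward pass down into every attention block, instead of recomputing them inside each layer. Each .item() or .tolist() on a CUDA tensor forces a device-to-host copy (a cudaMemcpyAsync plus a stream synchronization), so hoisting these calls out of the per-layer loop cuts the copy count from once per block to once per forward. A minimal sketch of the pattern, using placeholder names (blocks, cu_seqlens) rather than the actual vLLM code:

    # Before: one GPU->CPU copy and sync per layer.
    # for blk in blocks:
    #     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    #     x = blk(x, max_seqlen=max_seqlen)

    # After: compute the host-side value once, reuse it in every layer.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # single copy
    for blk in blocks:
        x = blk(x, max_seqlen=max_seqlen)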
@@ -259,6 +259,8 @@ class Qwen2_5_VisionAttention(nn.Module):
             x: torch.Tensor,
             cu_seqlens: torch.Tensor,
             rotary_pos_emb: torch.Tensor,
+            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+            seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -285,7 +287,6 @@ class Qwen2_5_VisionAttention(nn.Module):
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -321,7 +322,6 @@ class Qwen2_5_VisionAttention(nn.Module):
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -364,11 +364,20 @@ class Qwen2_5_VisionBlock(nn.Module):
                        quant_config=quant_config,
                        prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+    def forward(
+            self,
+            x: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor,
+            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+            seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
         x = x + self.attn(self.norm1(x),
                           cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+                          rotary_pos_emb=rotary_pos_emb,
+                          max_seqlen=max_seqlen,
+                          seqlens=seqlens)
         x = x + self.mlp(self.norm2(x))
         return x
@@ -528,6 +537,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -633,14 +643,25 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
         # transformers
         hidden_states = hidden_states.unsqueeze(1)
 
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
             else:
                 cu_seqlens_now = cu_window_seqlens
-            hidden_states = blk(hidden_states,
-                                cu_seqlens=cu_seqlens_now,
-                                rotary_pos_emb=rotary_pos_emb)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # For Qwen2.5-VL-3B, float16 will overflow at last block
         # for long visual tokens sequences.
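The hunks below apply the same refactor to the original Qwen2-VL vision tower (Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer). Note that .tolist() synchronizes for the same reason .item() does: BlockDiagonalMask.from_seqlens takes Python ints, so the CUDA tensor must first be copied to the host. A small sketch of the xFormers path, reusing names from the hunks above:

    # One host copy per forward instead of one per layer:
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                               kv_seqlen=None,
                                               device=q.device)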
@@ -307,6 +307,8 @@ class Qwen2VisionAttention(nn.Module):
             x: torch.Tensor,
             cu_seqlens: torch.Tensor,
             rotary_pos_emb: torch.Tensor,
+            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+            seqlens: Optional[list[int]] = None,  # Only used for xFormers
     ) -> torch.Tensor:
 
         # [s, b, c] --> [s, b, 3 * head * head_dim]
@@ -329,7 +331,6 @@ class Qwen2VisionAttention(nn.Module):
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -365,7 +366,6 @@ class Qwen2VisionAttention(nn.Module):
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -409,11 +409,22 @@ class Qwen2VisionBlock(nn.Module):
                        quant_config=quant_config,
                        prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x),
-                          cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+    def forward(
+            self,
+            x: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor,
+            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
+            seqlens: Optional[list[int]] = None,  # Only used for xFormers
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
         x = x + self.mlp(self.norm2(x))
         return x
@@ -570,6 +581,7 @@ class Qwen2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -624,8 +636,21 @@ class Qwen2VisionTransformer(nn.Module):
 
         # transformers
        x = x.unsqueeze(1)
 
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # adapter
         x = self.merger(x)
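One way to confirm the reduction (illustrative only, not part of the commit) is to count Memcpy events with the PyTorch profiler around a single vision-tower forward before and after the change; model and inputs here are hypothetical stand-ins:

    import torch
    from torch.profiler import ProfilerActivity, profile

    # Count the copy events recorded during one forward pass.
    with profile(activities=[ProfilerActivity.CPU,
                             ProfilerActivity.CUDA]) as prof:
        model(**inputs)  # hypothetical forward call
        torch.cuda.synchronize()
    n_memcpy = sum(e.count for e in prof.key_averages() if "Memcpy" in e.key)
    print(f"memcpy events: {n_memcpy}")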