[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)

Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
This commit is contained in:
tianshu-Michael-yu
2026-01-23 05:20:30 -08:00
committed by GitHub
parent 10e94c84f6
commit 13d8746c54
5 changed files with 260 additions and 158 deletions

View File

@@ -40,99 +40,111 @@ class Siglip2VisionEmbeddings(nn.Module):
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
@staticmethod
def resize_positional_embeddings(
positional_embeddings: torch.Tensor,
def forward(
self,
pixel_values_packed: torch.FloatTensor,
spatial_shapes: torch.LongTensor,
max_length: int,
) -> torch.Tensor:
"""
Resize positional embeddings to image-specific size and pad to a fixed size.
"""Embed patchified pixel values in packed (unpadded) form.
Args:
positional_embeddings (`torch.Tensor`):
Position embeddings of shape (height, width, embed_dim)
spatial_shapes (`torch.LongTensor`):
Spatial shapes of shape (batch_size, 2) to resize the positional
embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized
positional embeddings to
pixel_values_packed: (1, total_tokens, patch_dim) or
(total_tokens, patch_dim), packed in tile order.
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
(1, total_tokens, embed_dim) packed embeddings.
"""
batch_size = spatial_shapes.shape[0]
assert spatial_shapes.device.type == "cpu", (
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
"variable-length packing."
)
if pixel_values_packed.dim() == 3:
assert pixel_values_packed.shape[0] == 1
pixel_values_flat = pixel_values_packed[0]
else:
pixel_values_flat = pixel_values_packed
lengths = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(dtype=torch.int64)
lengths_list = lengths.tolist()
total_tokens = int(sum(lengths_list))
if total_tokens != pixel_values_flat.shape[0]:
raise ValueError(
"Packed pixel_values token count does not match spatial_shapes: "
f"{pixel_values_flat.shape[0]} vs {total_tokens}."
)
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values_flat.to(dtype=target_dtype))
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
packed_pos_embeds = self.resize_positional_embeddings_packed(
positional_embeddings,
spatial_shapes,
lengths_list=lengths_list,
)
embeddings = patch_embeds + packed_pos_embeds
return embeddings.unsqueeze(0)
@staticmethod
def resize_positional_embeddings_packed(
positional_embeddings: torch.Tensor,
spatial_shapes: torch.LongTensor,
lengths_list: list[int],
) -> torch.Tensor:
"""Resize positional embeddings per image and return a packed tensor.
Args:
positional_embeddings: (height, width, embed_dim) base grid.
spatial_shapes: (batch_size, 2) on CPU, (height, width) per image.
lengths_list: flattened token length per image (height * width).
Returns:
(total_tokens, embed_dim) packed positional embeddings, concatenated
in the same order as `lengths_list`.
"""
assert spatial_shapes.device.type == "cpu"
embed_dim = positional_embeddings.shape[-1]
source_dtype = positional_embeddings.dtype
resulted_positional_embeddings = torch.empty(
(batch_size, max_length, embed_dim),
total_tokens = int(sum(lengths_list))
packed_pos_embeds = torch.empty(
(total_tokens, embed_dim),
device=positional_embeddings.device,
dtype=source_dtype,
)
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
# (height, width, embed_dim) -> (1, embed_dim, height, width)
pos_4d = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
# Upcast to float32 on CPU because antialias is not supported for
# bfloat16/float16 on CPU
if positional_embeddings.device.type == "cpu":
positional_embeddings = positional_embeddings.to(torch.float32)
# bfloat16/float16 on CPU.
if pos_4d.device.type == "cpu":
pos_4d = pos_4d.to(torch.float32)
for i in range(batch_size):
# (1, dim, height, width) -> (1, dim, target_height, target_width)
height, width = spatial_shapes[i]
resized_embeddings = F.interpolate(
positional_embeddings,
offset = 0
for i, length in enumerate(lengths_list):
if length <= 0:
continue
height, width = spatial_shapes[i].tolist()
resized = F.interpolate(
pos_4d,
size=(height, width),
mode="bilinear",
align_corners=False,
antialias=True,
)
resized = resized.reshape(embed_dim, height * width).transpose(0, 1)
resized = resized.to(source_dtype)
packed_pos_embeds[offset : offset + length] = resized
offset += length
# (1, dim, target_height, target_width) ->
# (target_height * target_width, dim)
resized_embeddings = resized_embeddings.reshape(
embed_dim, height * width
).transpose(0, 1)
# Cast to original dtype
resized_embeddings = resized_embeddings.to(source_dtype)
resulted_positional_embeddings[i, : height * width] = resized_embeddings
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
return resulted_positional_embeddings
def forward(
self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor
) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (batch_size, max_num_patches,
num_channels * patch_size * patch_size)
spatial_shapes (`list[tuple[int, int]]`):
Spatial shapes of shape (batch_size, 2) to resize the positional
embeddings to
"""
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
# Get positional resized and padded positional embeddings
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
resized_positional_embeddings = self.resize_positional_embeddings(
positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
)
# Add positional embeddings to patch embeddings
embeddings = patch_embeds + resized_positional_embeddings
return embeddings
return packed_pos_embeds
class Siglip2Attention(nn.Module):
@@ -402,36 +414,23 @@ class Siglip2VisionTransformer(nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor,
pixel_values_packed: torch.FloatTensor,
spatial_shapes: torch.LongTensor,
packed_mask: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int | torch.Tensor,
max_seqlen: torch.Tensor,
) -> torch.Tensor:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
of the input images.
of the input images.
"""
hidden_states = self.embeddings(pixel_values, spatial_shapes)
flat_mask = packed_mask.view(-1)
packed_indices = flat_mask.nonzero(as_tuple=True)[0]
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = flat_hidden_states.index_select(0, packed_indices).unsqueeze(0)
hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
unpacked = encoder_outputs.new_zeros(
packed_mask.numel(), encoder_outputs.shape[-1]
)
unpacked.index_copy_(0, packed_indices, encoder_outputs.squeeze(0))
encoder_outputs = unpacked.view(
packed_mask.shape + (encoder_outputs.shape[-1],)
)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
return self.post_layernorm(encoder_outputs)
class Siglip2Model(torch.nn.Module):
@@ -453,16 +452,14 @@ class Siglip2Model(torch.nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor,
pixel_values_packed: torch.FloatTensor,
spatial_shapes: torch.LongTensor,
packed_mask: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int | torch.Tensor,
max_seqlen: torch.Tensor,
) -> torch.Tensor:
return self.vision_model(
pixel_values=pixel_values,
pixel_values_packed=pixel_values_packed,
spatial_shapes=spatial_shapes,
packed_mask=packed_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

View File

@@ -50,7 +50,7 @@ from .interfaces import (
SupportsMultiModal,
SupportsPP,
)
from .siglip2 import Siglip2Model
from .lfm2_siglip2 import Siglip2Model
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -450,29 +450,78 @@ class Lfm2VLMultiModalProjector(nn.Module):
bias=config.projector_bias,
)
def forward(self, image_features: torch.Tensor):
image_features = self.pixel_unshuffle(image_features)
if self.projector_use_layernorm:
image_features = self.layer_norm(image_features)
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def forward(
self,
vision_features_packed: torch.Tensor,
spatial_shapes: torch.Tensor,
) -> torch.Tensor:
"""Project packed vision features without materializing padded tensors.
def pixel_unshuffle(self, hidden_states: torch.Tensor):
batch_size, width, height, channels = hidden_states.size()
hidden_states = hidden_states.reshape(
batch_size, width, height // self.factor, channels * self.factor
Args:
vision_features_packed: (total_tokens, hidden_size) packed in tile order.
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
Returns:
projected_packed: (total_projected_tokens, text_hidden_size)
"""
assert spatial_shapes.device.type == "cpu", (
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
"variable-length packing."
)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(
batch_size,
height // self.factor,
width // self.factor,
channels * self.factor**2,
)
hidden_states = hidden_states.permute(0, 2, 1, 3)
return hidden_states
factor = self.factor
device = vision_features_packed.device
hidden_size = vision_features_packed.shape[-1]
spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
lengths_list = [h * w for h, w in spatial_shapes_list]
gather_idx_parts: list[torch.Tensor] = []
offset = 0
dh = torch.arange(factor, dtype=torch.int64)
dw = torch.arange(factor, dtype=torch.int64)
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
dh_flat = dh_grid.reshape(-1)
dw_flat = dw_grid.reshape(-1)
for (height, width), length in zip(spatial_shapes_list, lengths_list):
if length <= 0:
continue
if height % factor != 0 or width % factor != 0:
raise ValueError(
"spatial_shapes must be divisible by downsample_factor: "
f"got ({height}, {width}) with factor={factor}."
)
height_out = height // factor
width_out = width // factor
rows_out = torch.arange(height_out, dtype=torch.int64)
cols_out = torch.arange(width_out, dtype=torch.int64)
rr, cc = torch.meshgrid(rows_out, cols_out, indexing="ij")
rr = rr.reshape(-1)
cc = cc.reshape(-1)
token_idx = (rr[:, None] * factor + dh_flat[None, :]) * width + (
cc[:, None] * factor + dw_flat[None, :]
)
gather_idx_parts.append(token_idx.reshape(-1) + offset)
offset += length
if gather_idx_parts:
gather_idx = torch.cat(gather_idx_parts).to(device=device)
gathered = vision_features_packed.index_select(0, gather_idx)
unshuffled = gathered.reshape(-1, factor * factor * hidden_size)
else:
unshuffled = vision_features_packed.new_empty(
(0, factor * factor * hidden_size)
)
if self.projector_use_layernorm:
unshuffled = self.layer_norm(unshuffled)
hidden_states = self.linear_1(unshuffled)
hidden_states = self.act(hidden_states)
projected_packed = self.linear_2(hidden_states)
return projected_packed
@MULTIMODAL_REGISTRY.register_processor(
@@ -598,61 +647,90 @@ class Lfm2VLForConditionalGeneration(
pixel_values: torch.FloatTensor,
spatial_shapes: torch.Tensor,
) -> torch.Tensor:
assert spatial_shapes.device.type == "cpu", (
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
"variable-length packing."
)
pixel_values = pixel_values.to(
dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility
# LFM2-VL's HF processor pads patch sequences with trailing zeros.
# Derive the valid-patch mask from spatial_shapes instead of carrying
# pixel_attention_mask through the vLLM multimodal pipeline.
max_seq_len = pixel_values.shape[1]
# Pack patch tokens upfront so the vision tower runs entirely unpadded.
spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
lengths_list = [h * w for h, w in spatial_shapes_list]
total_tokens = int(sum(lengths_list))
lengths_cpu = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(
dtype=torch.int32
)
max_seqlen = (
lengths_cpu.max().reshape(1).to(device=pixel_values.device)
lengths_cpu.max().reshape(1)
if lengths_cpu.numel()
else torch.tensor([0], dtype=torch.int32, device=pixel_values.device)
else torch.tensor([0], dtype=torch.int32)
)
lengths = lengths_cpu.to(device=pixel_values.device)
packed_mask = (
torch.arange(max_seq_len, device=pixel_values.device)[None, :]
< lengths[:, None]
if total_tokens == 0:
return []
packed_pixel_values = pixel_values.new_empty(
(total_tokens, pixel_values.shape[-1])
)
offset = 0
for i, length in enumerate(lengths_list):
if length <= 0:
continue
packed_pixel_values[offset : offset + length].copy_(
pixel_values[i, :length]
)
offset += length
packed_pixel_values = packed_pixel_values.unsqueeze(0)
lengths = torch.tensor(
lengths_list, dtype=torch.int32, device=pixel_values.device
)
cu_seqlens = torch.zeros(
lengths.shape[0] + 1,
dtype=torch.int32,
device=lengths.device,
device=pixel_values.device,
)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
with set_forward_context(None, self.vllm_config):
vision_outputs = self.vision_tower(
pixel_values=pixel_values,
pixel_values_packed=packed_pixel_values,
spatial_shapes=spatial_shapes,
packed_mask=packed_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
image_outputs = getattr(vision_outputs, "last_hidden_state", vision_outputs)
image_outputs_packed = getattr(
vision_outputs, "last_hidden_state", vision_outputs
)
vision_features_packed = image_outputs_packed[0]
image_features = []
factor = self.multi_modal_projector.factor
projected_lengths_list: list[int] = []
for (height, width), length in zip(spatial_shapes_list, lengths_list):
if length <= 0:
projected_lengths_list.append(0)
continue
if height % factor != 0 or width % factor != 0:
raise ValueError(
"spatial_shapes must be divisible by downsample_factor: "
f"got ({height}, {width}) with factor={factor}."
)
projected_lengths_list.append((height // factor) * (width // factor))
# spatial_shapes is on CPU (keep_on_cpu=True), so .tolist() is instant
spatial_shapes_list = spatial_shapes.tolist()
for img_idx, (feature_org_h, feature_org_w) in enumerate(spatial_shapes_list):
feature_len = feature_org_h * feature_org_w
feature = image_outputs[img_idx, :feature_len]
projected_packed = self.multi_modal_projector(
vision_features_packed=vision_features_packed,
spatial_shapes=spatial_shapes,
)
# reshape to original height and width
feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
# project the image representation
img_embedding = self.multi_modal_projector(feature)
# flatten here to handle variable length in naflex
img_embedding = img_embedding.reshape(-1, img_embedding.size(-1))
image_features.append(img_embedding)
image_features: list[torch.Tensor] = []
offset = 0
for out_len in projected_lengths_list:
image_features.append(projected_packed[offset : offset + out_len])
offset += out_len
return image_features

View File

@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
m = common_attn_metadata
query_start_loc = m.query_start_loc
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks_cpu.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
spec_sequence_masks_cpu = None
else:
spec_sequence_masks = spec_sequence_masks.to(
spec_sequence_masks = spec_sequence_masks_cpu.to(
query_start_loc.device, non_blocking=True
)
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert spec_sequence_masks_cpu is not None
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
non_spec_query_start_loc_cpu = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
dim=0,
out=non_spec_query_start_loc[1:],
)
non_spec_query_start_loc_cpu = torch.zeros(
query_lens_cpu.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
)
torch.cumsum(
query_lens_cpu[~spec_sequence_masks_cpu],
dim=0,
out=non_spec_query_start_loc_cpu[1:],
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
assert non_spec_query_start_loc_cpu is not None
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
compute_causal_conv1d_metadata(
non_spec_query_start_loc_cpu,
device=query_start_loc.device,
)
)
else:
has_initial_state = None

View File

@@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
if num_prefills > 0:
if num_computed_tokens is None:
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
num_computed_tokens_cpu = num_computed_tokens.cpu()
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
has_initial_states_p = (
num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
compute_causal_conv1d_metadata(
query_start_loc_p_cpu,
device=common_attn_metadata.query_start_loc.device,
)
)
if self.vllm_config.cache_config.enable_prefix_caching:

View File

@@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend(
return attn_backend
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
# Needed for causal_conv1d
seqlens = query_start_loc_p.diff().to("cpu")
def compute_causal_conv1d_metadata(
query_start_loc_p_cpu: torch.Tensor,
*,
device: torch.device,
):
# Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
assert query_start_loc_p_cpu.device.type == "cpu"
seqlens = query_start_loc_p_cpu.diff()
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
device = query_start_loc_p.device
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}