[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)
Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com>
This commit is contained in:
committed by
GitHub
parent
10e94c84f6
commit
13d8746c54
@@ -40,99 +40,111 @@ class Siglip2VisionEmbeddings(nn.Module):
|
||||
self.position_embedding_size = int(self.num_patches**0.5)
|
||||
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
|
||||
|
||||
@staticmethod
|
||||
def resize_positional_embeddings(
|
||||
positional_embeddings: torch.Tensor,
|
||||
def forward(
|
||||
self,
|
||||
pixel_values_packed: torch.FloatTensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
max_length: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resize positional embeddings to image-specific size and pad to a fixed size.
|
||||
"""Embed patchified pixel values in packed (unpadded) form.
|
||||
|
||||
Args:
|
||||
positional_embeddings (`torch.Tensor`):
|
||||
Position embeddings of shape (height, width, embed_dim)
|
||||
spatial_shapes (`torch.LongTensor`):
|
||||
Spatial shapes of shape (batch_size, 2) to resize the positional
|
||||
embeddings to
|
||||
max_length (`int`):
|
||||
Maximum length of the positional embeddings to pad resized
|
||||
positional embeddings to
|
||||
pixel_values_packed: (1, total_tokens, patch_dim) or
|
||||
(total_tokens, patch_dim), packed in tile order.
|
||||
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
|
||||
(1, total_tokens, embed_dim) packed embeddings.
|
||||
"""
|
||||
batch_size = spatial_shapes.shape[0]
|
||||
assert spatial_shapes.device.type == "cpu", (
|
||||
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
|
||||
"variable-length packing."
|
||||
)
|
||||
|
||||
if pixel_values_packed.dim() == 3:
|
||||
assert pixel_values_packed.shape[0] == 1
|
||||
pixel_values_flat = pixel_values_packed[0]
|
||||
else:
|
||||
pixel_values_flat = pixel_values_packed
|
||||
|
||||
lengths = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(dtype=torch.int64)
|
||||
lengths_list = lengths.tolist()
|
||||
total_tokens = int(sum(lengths_list))
|
||||
if total_tokens != pixel_values_flat.shape[0]:
|
||||
raise ValueError(
|
||||
"Packed pixel_values token count does not match spatial_shapes: "
|
||||
f"{pixel_values_flat.shape[0]} vs {total_tokens}."
|
||||
)
|
||||
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values_flat.to(dtype=target_dtype))
|
||||
|
||||
positional_embeddings = self.position_embedding.weight.reshape(
|
||||
self.position_embedding_size, self.position_embedding_size, -1
|
||||
)
|
||||
packed_pos_embeds = self.resize_positional_embeddings_packed(
|
||||
positional_embeddings,
|
||||
spatial_shapes,
|
||||
lengths_list=lengths_list,
|
||||
)
|
||||
|
||||
embeddings = patch_embeds + packed_pos_embeds
|
||||
return embeddings.unsqueeze(0)
|
||||
|
||||
@staticmethod
|
||||
def resize_positional_embeddings_packed(
|
||||
positional_embeddings: torch.Tensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
lengths_list: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""Resize positional embeddings per image and return a packed tensor.
|
||||
|
||||
Args:
|
||||
positional_embeddings: (height, width, embed_dim) base grid.
|
||||
spatial_shapes: (batch_size, 2) on CPU, (height, width) per image.
|
||||
lengths_list: flattened token length per image (height * width).
|
||||
|
||||
Returns:
|
||||
(total_tokens, embed_dim) packed positional embeddings, concatenated
|
||||
in the same order as `lengths_list`.
|
||||
"""
|
||||
assert spatial_shapes.device.type == "cpu"
|
||||
|
||||
embed_dim = positional_embeddings.shape[-1]
|
||||
source_dtype = positional_embeddings.dtype
|
||||
|
||||
resulted_positional_embeddings = torch.empty(
|
||||
(batch_size, max_length, embed_dim),
|
||||
total_tokens = int(sum(lengths_list))
|
||||
packed_pos_embeds = torch.empty(
|
||||
(total_tokens, embed_dim),
|
||||
device=positional_embeddings.device,
|
||||
dtype=source_dtype,
|
||||
)
|
||||
|
||||
# (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
|
||||
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
|
||||
# (height, width, embed_dim) -> (1, embed_dim, height, width)
|
||||
pos_4d = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
# Upcast to float32 on CPU because antialias is not supported for
|
||||
# bfloat16/float16 on CPU
|
||||
if positional_embeddings.device.type == "cpu":
|
||||
positional_embeddings = positional_embeddings.to(torch.float32)
|
||||
# bfloat16/float16 on CPU.
|
||||
if pos_4d.device.type == "cpu":
|
||||
pos_4d = pos_4d.to(torch.float32)
|
||||
|
||||
for i in range(batch_size):
|
||||
# (1, dim, height, width) -> (1, dim, target_height, target_width)
|
||||
height, width = spatial_shapes[i]
|
||||
resized_embeddings = F.interpolate(
|
||||
positional_embeddings,
|
||||
offset = 0
|
||||
for i, length in enumerate(lengths_list):
|
||||
if length <= 0:
|
||||
continue
|
||||
height, width = spatial_shapes[i].tolist()
|
||||
resized = F.interpolate(
|
||||
pos_4d,
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
resized = resized.reshape(embed_dim, height * width).transpose(0, 1)
|
||||
resized = resized.to(source_dtype)
|
||||
packed_pos_embeds[offset : offset + length] = resized
|
||||
offset += length
|
||||
|
||||
# (1, dim, target_height, target_width) ->
|
||||
# (target_height * target_width, dim)
|
||||
resized_embeddings = resized_embeddings.reshape(
|
||||
embed_dim, height * width
|
||||
).transpose(0, 1)
|
||||
|
||||
# Cast to original dtype
|
||||
resized_embeddings = resized_embeddings.to(source_dtype)
|
||||
|
||||
resulted_positional_embeddings[i, : height * width] = resized_embeddings
|
||||
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
|
||||
|
||||
return resulted_positional_embeddings
|
||||
|
||||
def forward(
|
||||
self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor`):
|
||||
Pixel values of shape (batch_size, max_num_patches,
|
||||
num_channels * patch_size * patch_size)
|
||||
spatial_shapes (`list[tuple[int, int]]`):
|
||||
Spatial shapes of shape (batch_size, 2) to resize the positional
|
||||
embeddings to
|
||||
"""
|
||||
|
||||
# Apply patch embeddings to already patchified pixel values
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
|
||||
# Get positional resized and padded positional embeddings
|
||||
positional_embeddings = self.position_embedding.weight.reshape(
|
||||
self.position_embedding_size, self.position_embedding_size, -1
|
||||
)
|
||||
resized_positional_embeddings = self.resize_positional_embeddings(
|
||||
positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
|
||||
)
|
||||
|
||||
# Add positional embeddings to patch embeddings
|
||||
embeddings = patch_embeds + resized_positional_embeddings
|
||||
return embeddings
|
||||
return packed_pos_embeds
|
||||
|
||||
|
||||
class Siglip2Attention(nn.Module):
|
||||
@@ -402,36 +414,23 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_values_packed: torch.FloatTensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
packed_mask: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: int | torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the spatial dimensions (height, width)
|
||||
of the input images.
|
||||
of the input images.
|
||||
"""
|
||||
hidden_states = self.embeddings(pixel_values, spatial_shapes)
|
||||
flat_mask = packed_mask.view(-1)
|
||||
packed_indices = flat_mask.nonzero(as_tuple=True)[0]
|
||||
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
hidden_states = flat_hidden_states.index_select(0, packed_indices).unsqueeze(0)
|
||||
hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
unpacked = encoder_outputs.new_zeros(
|
||||
packed_mask.numel(), encoder_outputs.shape[-1]
|
||||
)
|
||||
unpacked.index_copy_(0, packed_indices, encoder_outputs.squeeze(0))
|
||||
encoder_outputs = unpacked.view(
|
||||
packed_mask.shape + (encoder_outputs.shape[-1],)
|
||||
)
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
return self.post_layernorm(encoder_outputs)
|
||||
|
||||
|
||||
class Siglip2Model(torch.nn.Module):
|
||||
@@ -453,16 +452,14 @@ class Siglip2Model(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_values_packed: torch.FloatTensor,
|
||||
spatial_shapes: torch.LongTensor,
|
||||
packed_mask: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: int | torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_packed=pixel_values_packed,
|
||||
spatial_shapes=spatial_shapes,
|
||||
packed_mask=packed_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
@@ -50,7 +50,7 @@ from .interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .siglip2 import Siglip2Model
|
||||
from .lfm2_siglip2 import Siglip2Model
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
@@ -450,29 +450,78 @@ class Lfm2VLMultiModalProjector(nn.Module):
|
||||
bias=config.projector_bias,
|
||||
)
|
||||
|
||||
def forward(self, image_features: torch.Tensor):
|
||||
image_features = self.pixel_unshuffle(image_features)
|
||||
if self.projector_use_layernorm:
|
||||
image_features = self.layer_norm(image_features)
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
def forward(
|
||||
self,
|
||||
vision_features_packed: torch.Tensor,
|
||||
spatial_shapes: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Project packed vision features without materializing padded tensors.
|
||||
|
||||
def pixel_unshuffle(self, hidden_states: torch.Tensor):
|
||||
batch_size, width, height, channels = hidden_states.size()
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, width, height // self.factor, channels * self.factor
|
||||
Args:
|
||||
vision_features_packed: (total_tokens, hidden_size) packed in tile order.
|
||||
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
|
||||
|
||||
Returns:
|
||||
projected_packed: (total_projected_tokens, text_hidden_size)
|
||||
"""
|
||||
assert spatial_shapes.device.type == "cpu", (
|
||||
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
|
||||
"variable-length packing."
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size,
|
||||
height // self.factor,
|
||||
width // self.factor,
|
||||
channels * self.factor**2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
||||
return hidden_states
|
||||
factor = self.factor
|
||||
device = vision_features_packed.device
|
||||
hidden_size = vision_features_packed.shape[-1]
|
||||
|
||||
spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
|
||||
lengths_list = [h * w for h, w in spatial_shapes_list]
|
||||
|
||||
gather_idx_parts: list[torch.Tensor] = []
|
||||
offset = 0
|
||||
|
||||
dh = torch.arange(factor, dtype=torch.int64)
|
||||
dw = torch.arange(factor, dtype=torch.int64)
|
||||
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
|
||||
dh_flat = dh_grid.reshape(-1)
|
||||
dw_flat = dw_grid.reshape(-1)
|
||||
|
||||
for (height, width), length in zip(spatial_shapes_list, lengths_list):
|
||||
if length <= 0:
|
||||
continue
|
||||
if height % factor != 0 or width % factor != 0:
|
||||
raise ValueError(
|
||||
"spatial_shapes must be divisible by downsample_factor: "
|
||||
f"got ({height}, {width}) with factor={factor}."
|
||||
)
|
||||
height_out = height // factor
|
||||
width_out = width // factor
|
||||
|
||||
rows_out = torch.arange(height_out, dtype=torch.int64)
|
||||
cols_out = torch.arange(width_out, dtype=torch.int64)
|
||||
rr, cc = torch.meshgrid(rows_out, cols_out, indexing="ij")
|
||||
rr = rr.reshape(-1)
|
||||
cc = cc.reshape(-1)
|
||||
|
||||
token_idx = (rr[:, None] * factor + dh_flat[None, :]) * width + (
|
||||
cc[:, None] * factor + dw_flat[None, :]
|
||||
)
|
||||
gather_idx_parts.append(token_idx.reshape(-1) + offset)
|
||||
offset += length
|
||||
|
||||
if gather_idx_parts:
|
||||
gather_idx = torch.cat(gather_idx_parts).to(device=device)
|
||||
gathered = vision_features_packed.index_select(0, gather_idx)
|
||||
unshuffled = gathered.reshape(-1, factor * factor * hidden_size)
|
||||
else:
|
||||
unshuffled = vision_features_packed.new_empty(
|
||||
(0, factor * factor * hidden_size)
|
||||
)
|
||||
|
||||
if self.projector_use_layernorm:
|
||||
unshuffled = self.layer_norm(unshuffled)
|
||||
hidden_states = self.linear_1(unshuffled)
|
||||
hidden_states = self.act(hidden_states)
|
||||
projected_packed = self.linear_2(hidden_states)
|
||||
return projected_packed
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
@@ -598,61 +647,90 @@ class Lfm2VLForConditionalGeneration(
|
||||
pixel_values: torch.FloatTensor,
|
||||
spatial_shapes: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert spatial_shapes.device.type == "cpu", (
|
||||
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
|
||||
"variable-length packing."
|
||||
)
|
||||
|
||||
pixel_values = pixel_values.to(
|
||||
dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
|
||||
) # fp16 compatibility
|
||||
|
||||
# LFM2-VL's HF processor pads patch sequences with trailing zeros.
|
||||
# Derive the valid-patch mask from spatial_shapes instead of carrying
|
||||
# pixel_attention_mask through the vLLM multimodal pipeline.
|
||||
max_seq_len = pixel_values.shape[1]
|
||||
# Pack patch tokens upfront so the vision tower runs entirely unpadded.
|
||||
spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
|
||||
lengths_list = [h * w for h, w in spatial_shapes_list]
|
||||
total_tokens = int(sum(lengths_list))
|
||||
lengths_cpu = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(
|
||||
dtype=torch.int32
|
||||
)
|
||||
max_seqlen = (
|
||||
lengths_cpu.max().reshape(1).to(device=pixel_values.device)
|
||||
lengths_cpu.max().reshape(1)
|
||||
if lengths_cpu.numel()
|
||||
else torch.tensor([0], dtype=torch.int32, device=pixel_values.device)
|
||||
else torch.tensor([0], dtype=torch.int32)
|
||||
)
|
||||
lengths = lengths_cpu.to(device=pixel_values.device)
|
||||
packed_mask = (
|
||||
torch.arange(max_seq_len, device=pixel_values.device)[None, :]
|
||||
< lengths[:, None]
|
||||
|
||||
if total_tokens == 0:
|
||||
return []
|
||||
|
||||
packed_pixel_values = pixel_values.new_empty(
|
||||
(total_tokens, pixel_values.shape[-1])
|
||||
)
|
||||
offset = 0
|
||||
for i, length in enumerate(lengths_list):
|
||||
if length <= 0:
|
||||
continue
|
||||
packed_pixel_values[offset : offset + length].copy_(
|
||||
pixel_values[i, :length]
|
||||
)
|
||||
offset += length
|
||||
packed_pixel_values = packed_pixel_values.unsqueeze(0)
|
||||
|
||||
lengths = torch.tensor(
|
||||
lengths_list, dtype=torch.int32, device=pixel_values.device
|
||||
)
|
||||
cu_seqlens = torch.zeros(
|
||||
lengths.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=lengths.device,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
|
||||
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
vision_outputs = self.vision_tower(
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_packed=packed_pixel_values,
|
||||
spatial_shapes=spatial_shapes,
|
||||
packed_mask=packed_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
image_outputs = getattr(vision_outputs, "last_hidden_state", vision_outputs)
|
||||
image_outputs_packed = getattr(
|
||||
vision_outputs, "last_hidden_state", vision_outputs
|
||||
)
|
||||
vision_features_packed = image_outputs_packed[0]
|
||||
|
||||
image_features = []
|
||||
factor = self.multi_modal_projector.factor
|
||||
projected_lengths_list: list[int] = []
|
||||
for (height, width), length in zip(spatial_shapes_list, lengths_list):
|
||||
if length <= 0:
|
||||
projected_lengths_list.append(0)
|
||||
continue
|
||||
if height % factor != 0 or width % factor != 0:
|
||||
raise ValueError(
|
||||
"spatial_shapes must be divisible by downsample_factor: "
|
||||
f"got ({height}, {width}) with factor={factor}."
|
||||
)
|
||||
projected_lengths_list.append((height // factor) * (width // factor))
|
||||
|
||||
# spatial_shapes is on CPU (keep_on_cpu=True), so .tolist() is instant
|
||||
spatial_shapes_list = spatial_shapes.tolist()
|
||||
for img_idx, (feature_org_h, feature_org_w) in enumerate(spatial_shapes_list):
|
||||
feature_len = feature_org_h * feature_org_w
|
||||
feature = image_outputs[img_idx, :feature_len]
|
||||
projected_packed = self.multi_modal_projector(
|
||||
vision_features_packed=vision_features_packed,
|
||||
spatial_shapes=spatial_shapes,
|
||||
)
|
||||
|
||||
# reshape to original height and width
|
||||
feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
|
||||
|
||||
# project the image representation
|
||||
img_embedding = self.multi_modal_projector(feature)
|
||||
|
||||
# flatten here to handle variable length in naflex
|
||||
img_embedding = img_embedding.reshape(-1, img_embedding.size(-1))
|
||||
image_features.append(img_embedding)
|
||||
image_features: list[torch.Tensor] = []
|
||||
offset = 0
|
||||
for out_len in projected_lengths_list:
|
||||
image_features.append(projected_packed[offset : offset + out_len])
|
||||
offset += out_len
|
||||
|
||||
return image_features
|
||||
|
||||
|
||||
@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
m = common_attn_metadata
|
||||
|
||||
query_start_loc = m.query_start_loc
|
||||
query_start_loc_cpu = m.query_start_loc_cpu
|
||||
context_lens_tensor = m.compute_num_computed_tokens()
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
spec_sequence_masks_cpu: torch.Tensor | None = None
|
||||
if (
|
||||
not self.use_spec_decode
|
||||
or num_decode_draft_tokens_cpu is None
|
||||
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
spec_sequence_masks_cpu = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks_cpu.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
spec_sequence_masks_cpu = None
|
||||
else:
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
spec_sequence_masks = spec_sequence_masks_cpu.to(
|
||||
query_start_loc.device, non_blocking=True
|
||||
)
|
||||
|
||||
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
|
||||
spec_query_start_loc = None
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc_cpu = query_start_loc_cpu
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
assert spec_sequence_masks_cpu is not None
|
||||
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
non_spec_query_start_loc = None
|
||||
non_spec_query_start_loc_cpu = None
|
||||
else:
|
||||
spec_token_masks = torch.repeat_interleave(
|
||||
spec_sequence_masks, query_lens
|
||||
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:],
|
||||
)
|
||||
non_spec_query_start_loc_cpu = torch.zeros(
|
||||
query_lens_cpu.size(0) - num_spec_decodes + 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
torch.cumsum(
|
||||
query_lens_cpu[~spec_sequence_masks_cpu],
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc_cpu[1:],
|
||||
)
|
||||
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
assert non_spec_query_start_loc_cpu is not None
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
compute_causal_conv1d_metadata(
|
||||
non_spec_query_start_loc_cpu,
|
||||
device=query_start_loc.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
has_initial_state = None
|
||||
|
||||
@@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
if num_prefills > 0:
|
||||
if num_computed_tokens is None:
|
||||
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
|
||||
num_computed_tokens_cpu = num_computed_tokens.cpu()
|
||||
|
||||
query_start_loc_p_cpu = (
|
||||
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs] > 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
has_initial_states_p = (
|
||||
num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
|
||||
)
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = (
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
compute_causal_conv1d_metadata(
|
||||
query_start_loc_p_cpu,
|
||||
device=common_attn_metadata.query_start_loc.device,
|
||||
)
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
|
||||
@@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend(
|
||||
return attn_backend
|
||||
|
||||
|
||||
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
||||
# Needed for causal_conv1d
|
||||
seqlens = query_start_loc_p.diff().to("cpu")
|
||||
def compute_causal_conv1d_metadata(
|
||||
query_start_loc_p_cpu: torch.Tensor,
|
||||
*,
|
||||
device: torch.device,
|
||||
):
|
||||
# Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
|
||||
assert query_start_loc_p_cpu.device.type == "cpu"
|
||||
seqlens = query_start_loc_p_cpu.diff()
|
||||
nums_dict = {} # type: ignore
|
||||
batch_ptr = None
|
||||
token_chunk_offset_ptr = None
|
||||
device = query_start_loc_p.device
|
||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||
nums = -(-seqlens // BLOCK_M)
|
||||
nums_dict[BLOCK_M] = {}
|
||||
|
||||
Reference in New Issue
Block a user