[V1] Add V1 support of Qwen2-VL (#12128)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: imkero <kerorek@outlook.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Roger Wang
Date: 2025-01-19 03:52:13 -08:00
Committed by: GitHub
Parent: edaae198e7
Commit: 81763c58a0

9 changed files with 291 additions and 84 deletions

@@ -30,6 +30,9 @@ class CachedRequestState:
     num_computed_tokens: int
     output_token_ids: List[int]
 
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[int] = None
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
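The two new fields split the M-RoPE bookkeeping between prefill and decode: the full 3D positions are pre-computed for the prompt, while for generated tokens only a scalar offset needs to be cached. A minimal sketch of the idea (the helper name `next_mrope_positions` is ours, not vLLM's; the actual work is done by `MRotaryEmbedding.get_next_input_positions_tensor` later in this commit):

import torch

# Sketch only: decode-time M-RoPE positions are plain 1D context indices
# shifted by the cached delta, identical across the three rows
# (temporal, height, width).
def next_mrope_positions(delta: int, context_len: int, seq_len: int) -> torch.Tensor:
    pos = torch.arange(context_len, seq_len) + delta
    return pos.unsqueeze(0).expand(3, -1)  # shape (3, seq_len - context_len)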

@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
@@ -139,6 +140,32 @@ class GPUModelRunner:
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+        if self.model_config.uses_mrope:
+            # NOTE: `mrope_positions` is implemented as a permuted tensor to
+            # satisfy the following properties to allow `torch.compile` to
+            # work properly:
+            # - shape: (3, <variable>)
+            # - stride: (1, 3)
+            # See detailed explanation in
+            # https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
+
+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension
+            # has identical position IDs, making M-RoPE functionally
+            # equivalent to 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
+            self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
+                                               dtype=torch.int64,
+                                               device=self.device)
+            self.mrope_positions_cpu = torch.zeros(
+                (self.max_num_tokens, 3),
+                dtype=torch.int64,
+                device="cpu",
+                pin_memory=self.pin_memory)
+            self.mrope_positions = self.mrope_positions.permute((1, 0))
+            self.mrope_positions_cpu = self.mrope_positions_cpu.permute(
+                (1, 0))
 
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
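The permute trick in the NOTE above is easy to verify in isolation: allocating row-major (max_num_tokens, 3) and permuting yields a tensor indexed as (3, N) while keeping the fixed (1, 3) stride that `torch.compile` can specialize on. A standalone repro with an illustrative size:

import torch

max_num_tokens = 8  # illustrative; the real buffer uses self.max_num_tokens
mrope_positions = torch.zeros((max_num_tokens, 3), dtype=torch.int64)
mrope_positions = mrope_positions.permute((1, 0))

print(mrope_positions.shape)     # torch.Size([3, 8])
print(mrope_positions.stride())  # (1, 3)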
@@ -246,6 +273,35 @@
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
             )
+
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+            if self.model_config.uses_mrope:
+                image_grid_thw = []
+                video_grid_thw = []
+                for mm_input in self.requests[req_id].mm_inputs:
+                    if mm_input.get("image_grid_thw") is not None:
+                        image_grid_thw.extend(
+                            mm_input["image_grid_thw"].tolist())
+                    if mm_input.get("video_grid_thw") is not None:
+                        video_grid_thw.extend(
+                            mm_input["video_grid_thw"].tolist())
+
+                hf_config = self.model_config.hf_config
+
+                self.requests[req_id].mrope_positions, \
+                    self.requests[req_id].mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions_tensor(
+                        self.requests[req_id].prompt_token_ids,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        image_token_id=hf_config.image_token_id,
+                        video_token_id=hf_config.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
+                        spatial_merge_size=hf_config.vision_config.
+                        spatial_merge_size,
+                    )
+
             req_ids_to_add.append(req_id)
 
         # Update the cached states of the resumed requests.
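`image_grid_thw` / `video_grid_thw` carry the (t, h, w) extents of each vision item; for a request with no vision inputs both lists stay empty and, per the NOTE in the `__init__` hunk, the pre-computed 3D positions degenerate to ordinary 1D RoPE. An illustration of that text-only case:

import torch

n = 5  # illustrative prompt length
text_only_positions = torch.arange(n).unsqueeze(0).expand(3, -1)
print(text_only_positions)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4]])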
@@ -313,6 +369,11 @@
                arange,
                out=positions_np)
 
+        # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+        if self.model_config.uses_mrope:
+            self._calc_mrope_positions(scheduler_output)
+
         # Get token indices.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
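The token-index comment above is plain stride arithmetic: per-request positions become indices into a row-major (num_reqs, M) token table by adding req_index * M. A sketch with illustrative names (ours, not vLLM's):

import numpy as np

M = 100  # stand-in for the per-row stride (max_model_len in the comment)
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
positions = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
token_indices = positions + req_indices * M
print(token_indices)  # [  0   1 100 101 102 103 104 200 201 202]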
@@ -359,8 +420,16 @@
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        if self.model_config.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                non_blocking=True)
+        else:
+            # Common case (1D positions)
+            self.positions[:total_num_scheduled_tokens].copy_(
+                self.positions_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
         query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
             self.device, non_blocking=True)
         seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
@@ -472,6 +541,61 @@
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices
 
+    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+        mrope_pos_ptr = 0
+        num_reqs = self.input_batch.num_reqs
+        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            assert req_id is not None
+
+            req = self.requests[req_id]
+            assert req.mrope_positions is not None
+
+            num_computed_tokens = \
+                self.input_batch.num_computed_tokens_cpu[index]
+            num_scheduled_tokens = \
+                scheduler_output.num_scheduled_tokens[req_id]
+            num_prompt_tokens = len(req.prompt_token_ids)
+
+            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                prompt_part_len = max(0,
+                                      num_prompt_tokens - num_computed_tokens)
+                completion_part_len = max(
+                    0, num_scheduled_tokens - prompt_part_len)
+            else:
+                prompt_part_len = num_scheduled_tokens
+                completion_part_len = 0
+
+            assert num_scheduled_tokens == prompt_part_len + \
+                completion_part_len
+
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = num_computed_tokens
+                src_end = num_computed_tokens + prompt_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    req.mrope_positions[:, src_start:src_end]
+
+                mrope_pos_ptr += prompt_part_len
+
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+
+                self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                    MRotaryEmbedding.get_next_input_positions_tensor(
+                        req.mrope_position_delta,
+                        context_len=num_computed_tokens + prompt_part_len,
+                        seq_len=num_computed_tokens + prompt_part_len +
+                        completion_part_len,
+                    )
+
+                mrope_pos_ptr += completion_part_len
+
     def _prepare_sampling(
         self,
         scheduler_output: "SchedulerOutput",
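The prompt/completion split in `_calc_mrope_positions` handles a chunked-prefill step that crosses into decode. A worked example with illustrative numbers:

num_prompt_tokens = 10
num_computed_tokens = 8    # processed in earlier steps
num_scheduled_tokens = 5   # scheduled for this step

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)     # 2
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)  # 3
else:
    prompt_part_len = num_scheduled_tokens
    completion_part_len = 0

assert (prompt_part_len, completion_part_len) == (2, 3)
# 2 positions come from the request's pre-computed prompt mrope_positions;
# 3 are generated on the fly from mrope_position_delta.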
@@ -618,9 +742,12 @@
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
+            positions = self.mrope_positions[:, :num_input_tokens] \
+                if self.model_config.uses_mrope \
+                else self.positions[:num_input_tokens]
             hidden_states = self.model(
                 input_ids=input_ids,
-                positions=self.positions[:num_input_tokens],
+                positions=positions,
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
                 inputs_embeds=inputs_embeds,
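One reason this persistent-buffer approach composes with the layout NOTE from `__init__`: slicing the permuted buffer along dim 1, as `self.mrope_positions[:, :num_input_tokens]` does here, preserves the (1, 3) stride, so (presumably) the layout seen at CUDA-graph capture / compile time never changes between steps. A standalone check:

import torch

buf = torch.zeros((8, 3), dtype=torch.int64).permute((1, 0))
view = buf[:, :5]
print(view.shape, view.stride())  # torch.Size([3, 5]) (1, 3)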
@@ -707,9 +834,12 @@
         input_ids = self.input_ids[:num_tokens]
         inputs_embeds = None
         with set_forward_context(None, self.vllm_config):
+            positions = self.mrope_positions[:, :num_tokens] \
+                if self.model_config.uses_mrope \
+                else self.positions[:num_tokens]
             hidden_states = model(
                 input_ids=input_ids,
-                positions=self.positions[:num_tokens],
+                positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
                 inputs_embeds=inputs_embeds,