Add the support for the qwen3 next model (a hybrid attention model). (#24526)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -53,9 +53,9 @@ from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||
get_dtype_size, is_pin_memory_available, round_up,
|
||||
supports_dynamo)
|
||||
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
|
||||
is_pin_memory_available, round_up, supports_dynamo)
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
create_fast_prefill_custom_backend,
|
||||
@@ -324,6 +324,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
numpy=False)
|
||||
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int64)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
@@ -663,6 +667,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Refresh batch metadata with any pending updates.
|
||||
self.input_batch.refresh_metadata()
|
||||
|
||||
def _update_states_after_model_execute(
|
||||
self, output_token_ids: torch.Tensor) -> None:
|
||||
"""Update the cached states after model execution.
|
||||
|
||||
This is used for MTP/EAGLE for hybrid models, as in linear attention,
|
||||
only the last token's state is kept. In MTP/EAGLE, for draft tokens
|
||||
the states are kept until we decide how many tokens are accepted for
|
||||
each sequence, and a shifting is done during the next iteration
|
||||
based on the number of accepted tokens.
|
||||
"""
|
||||
if not self.model_config.is_hybrid or not self.speculative_config:
|
||||
return
|
||||
|
||||
# Find the number of accepted tokens for each sequence.
|
||||
num_accepted_tokens = (torch.cat(
|
||||
[
|
||||
output_token_ids,
|
||||
torch.full((output_token_ids.size(0), 1),
|
||||
-1,
|
||||
device=output_token_ids.device),
|
||||
],
|
||||
dim=1) == -1).int().argmax(-1).cpu().numpy()
|
||||
for i, num_tokens in enumerate(num_accepted_tokens):
|
||||
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
||||
|
||||
def _init_mrope_positions(self, req_state: CachedRequestState):
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
@@ -936,6 +965,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
num_draft_tokens = None
|
||||
spec_decode_metadata = None
|
||||
else:
|
||||
# Get the number of draft tokens for each request.
|
||||
@@ -950,6 +980,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
|
||||
self.num_draft_tokens.np[num_reqs:].fill(0)
|
||||
self.num_draft_tokens.copy_to_gpu()
|
||||
|
||||
logits_indices_padded = None
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
@@ -964,6 +997,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
spec_decode_common_attn_metadata = None
|
||||
if use_spec_decode:
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
@@ -1034,10 +1072,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
builder,
|
||||
)
|
||||
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder,
|
||||
GDNAttentionMetadataBuilder):
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.
|
||||
gpu[:num_reqs],
|
||||
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
|
||||
)
|
||||
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
**extra_attn_metadata_args)
|
||||
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
@@ -1814,6 +1861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
|
||||
return sampler_output
|
||||
|
||||
@@ -2644,13 +2692,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Note: Overriding max_query_len to be the prefill tokens
|
||||
max_query_len = num_prefill_tokens
|
||||
elif uniform_decode:
|
||||
assert not create_mixed_batch
|
||||
num_reqs = cdiv(num_tokens, max_query_len)
|
||||
num_reqs = num_tokens // max_query_len
|
||||
assert num_reqs <= max_num_reqs, \
|
||||
"Do not capture num_reqs > max_num_reqs for uniform batch"
|
||||
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
||||
if num_tokens % max_query_len != 0:
|
||||
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
||||
num_scheduled_tokens_list[-1] += num_tokens % max_query_len
|
||||
else:
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
@@ -3297,6 +3344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
logitsprocs=self.input_batch.logitsprocs,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
num_speculative_tokens=(
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config else 0),
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
@@ -3647,7 +3697,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
if (self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
@@ -3666,7 +3718,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type)
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
Reference in New Issue
Block a user