[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -96,3 +96,5 @@ albumentations==1.4.6
|
||||
transformers==4.57.3
|
||||
# Pin HF Hub version
|
||||
huggingface-hub==0.36.2
|
||||
# Pin Mistral Common
|
||||
mistral-common[image,audio]==1.9.1
|
||||
|
||||
@@ -419,7 +419,6 @@ class HfRunner:
|
||||
self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
|
||||
AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
)
|
||||
@@ -430,7 +429,6 @@ class HfRunner:
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if skip_tokenizer_init:
|
||||
|
||||
@@ -4,16 +4,18 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from transformers import VoxtralForConditionalGeneration
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ....conftest import AudioTestAssets
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from ...utils import check_logprobs_close
|
||||
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test
|
||||
from .vlm_utils import model_utils
|
||||
|
||||
MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
|
||||
MISTRAL_FORMAT_ARGS = [
|
||||
@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server(request, audio_assets: AudioTestAssets):
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
"--limit-mm-per-prompt",
|
||||
json.dumps({"audio": len(audio_assets)}),
|
||||
] + MISTRAL_FORMAT_ARGS
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
def _get_prompt(audio_assets, question):
|
||||
def _get_prompt(audio_assets: AudioTestAssets, question: str) -> list[int]:
|
||||
"""Build a token-ID prompt via mistral_common for vLLM offline inference."""
|
||||
tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)
|
||||
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(len(audio_assets))
|
||||
Audio.from_file(str(asset.get_local_path()), strict=False)
|
||||
for asset in audio_assets
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]
|
||||
|
||||
messages = [
|
||||
UserMessage(content=[*audio_chunks, TextChunk(text=question)]).to_openai()
|
||||
]
|
||||
return tokenizer.apply_chat_template(messages=messages)
|
||||
|
||||
|
||||
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
|
||||
vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
|
||||
run_multi_audio_test(
|
||||
vllm_runner,
|
||||
[(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])],
|
||||
[(vllm_prompt, [a.audio_and_sample_rate for a in audio_assets])], # type: ignore[list-item]
|
||||
MODEL_NAME,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
@@ -86,30 +69,142 @@ def test_models_with_multiple_audios(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_online_serving(client, audio_assets: AudioTestAssets):
|
||||
"""Exercises online serving with/without chunked prefill enabled."""
|
||||
def test_online_serving(vllm_runner, audio_assets: AudioTestAssets):
|
||||
"""Two-layer accuracy and serving validation using Mistral format.
|
||||
|
||||
def asset_to_chunk(asset):
|
||||
1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
|
||||
with multiprocessing - see vlm_utils/core.py).
|
||||
2. Online OpenAI-compatible API output must match offline — validates
|
||||
that the serving path (chat template, audio encoding, tokenization)
|
||||
does not corrupt anything.
|
||||
|
||||
Steps run sequentially so each releases the GPU before the next starts.
|
||||
"""
|
||||
|
||||
question = f"What's happening in these {len(audio_assets)} audio clips?"
|
||||
max_tokens = 10
|
||||
audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
|
||||
|
||||
vllm_prompt = _get_prompt(audio_assets, question)
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
dtype="half",
|
||||
enforce_eager=True,
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
limit_mm_per_prompt={"audio": len(audio_assets)},
|
||||
) as vllm_model:
|
||||
offline_outputs = vllm_model.generate_greedy(
|
||||
[vllm_prompt],
|
||||
max_tokens,
|
||||
audios=[audio_data],
|
||||
)
|
||||
|
||||
offline_text = offline_outputs[0][1]
|
||||
assert offline_text, "Offline vLLM inference produced empty output"
|
||||
|
||||
def _asset_to_openai_chunk(asset):
|
||||
audio = Audio.from_file(str(asset.get_local_path()), strict=False)
|
||||
audio.format = "wav"
|
||||
audio_dict = AudioChunk.from_audio(audio).to_openai()
|
||||
return audio_dict
|
||||
return AudioChunk.from_audio(audio).to_openai()
|
||||
|
||||
audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
|
||||
text = f"What's happening in these {len(audio_assets)} audio clips?"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [*audio_chunks, {"type": "text", "text": text}],
|
||||
"content": [
|
||||
*[_asset_to_openai_chunk(a) for a in audio_assets],
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME, messages=messages, max_tokens=10
|
||||
server_args = [
|
||||
"--enforce-eager",
|
||||
"--limit-mm-per-prompt",
|
||||
json.dumps({"audio": len(audio_assets)}),
|
||||
*MISTRAL_FORMAT_ARGS,
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
server_args,
|
||||
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"},
|
||||
) as remote_server:
|
||||
client = remote_server.get_client()
|
||||
completion = client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
assert len(completion.choices) == 1
|
||||
choice = completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert choice.message.content == offline_text, (
|
||||
f"Online serving output does not match offline inference.\n"
|
||||
f" Online: {choice.message.content!r}\n"
|
||||
f" Offline: {offline_text!r}"
|
||||
)
|
||||
|
||||
assert len(chat_completion.choices) == 1
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.message.content == "In the first audio clip, you hear a brief"
|
||||
assert choice.finish_reason == "length"
|
||||
|
||||
def test_hf_reference(hf_runner, vllm_runner, audio_assets: AudioTestAssets):
    """Compare vLLM Mistral-format output against the HF Transformers reference.

    An exact text match is brittle across attention backends, so the two
    runners are compared on per-token logprobs through the shared
    ``check_logprobs_close`` helper: when the tokens diverge at a position,
    each runner's chosen token must appear in the other's top-k logprobs.

    vLLM runs first and its context manager is exited before the HF model is
    created, so the two models are not held open at the same time.

    NOTE(review): an earlier revision of this docstring said the test is
    marked ``xfail(strict=False)``. No such marker is applied to this
    function here, so a logprob mismatch fails the test. Confirm whether the
    marker is meant to be added or is applied elsewhere.
    """
    question = f"What's happening in these {len(audio_assets)} audio clips?"
    max_tokens = 10
    num_logprobs = 5
    # One (waveform, sample_rate) entry per asset; both runners get the same data.
    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]

    # vLLM consumes a token-ID prompt built with mistral_common.
    vllm_prompt = _get_prompt(audio_assets, question)
    with vllm_runner(
        MODEL_NAME,
        dtype="half",
        enforce_eager=True,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
        limit_mm_per_prompt={"audio": len(audio_assets)},
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            [vllm_prompt],
            max_tokens,
            num_logprobs,
            audios=[audio_data],
        )
    assert vllm_outputs[0][1], "vLLM inference produced empty output"

    # The HF side receives the raw question text; the patched runner builds
    # the conversation (audio + text) itself.
    with hf_runner(
        MODEL_NAME,
        dtype="half",
        auto_cls=VoxtralForConditionalGeneration,
    ) as hf_model:
        hf_model = model_utils.voxtral_patch_hf_runner(hf_model)
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            [question],
            max_tokens,
            num_logprobs,
            audios=[audio_data],
        )
    assert hf_outputs[0][1], "HF Transformers produced empty output"

    # Printed so both generations are visible in the test log on failure.
    print(
        "HF Reference Comparison\n"
        f"  vLLM: {vllm_outputs[0][1]!r}\n"
        f"  HF: {hf_outputs[0][1]!r}"
    )
    check_logprobs_close(
        outputs_0_lst=vllm_outputs,
        outputs_1_lst=hf_outputs,
        name_0="vllm",
        name_1="hf",
    )
|
||||
|
||||
@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
|
||||
TranscriptionRequest,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.4,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
|
||||
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
|
||||
|
||||
output_tokens_list.append(output_tokens)
|
||||
|
||||
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
|
||||
texts = [
|
||||
tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
|
||||
for output_tokens in output_tokens_list
|
||||
]
|
||||
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
|
||||
assert texts == EXPECTED_TEXT
|
||||
|
||||
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
hf_processor.patch_size = vision_encoder_info.get_patch_size()
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
    """Adapt an ``HfRunner`` to Voxtral's conversation-based processor.

    Two replacements are installed on ``hf_model`` and the same object is
    returned:

    * ``get_inputs`` is replaced by a version that turns every prompt and its
      audio clips into one user message and feeds it to the processor's
      ``apply_chat_template``. Per the original author's notes,
      VoxtralProcessor expects conversation dicts (audio given as ``url``,
      ``path`` or ``base64``) instead of ``processor(text=, audio=,
      sampling_rate=)``, and the stock ``get_inputs`` mis-unpacks several
      ``(array, sample_rate)`` pairs for a single prompt.
    * ``model.generate`` is wrapped so the prompt tokens are cut off the
      returned sequences; the original author notes that HfRunner otherwise
      decodes prompt + generation together.
    """

    import base64
    import io

    import soundfile as sf

    processor = hf_model.processor
    fallback_sample_rate = 16_000

    def _encode_wav_b64(samples, rate) -> str:
        """Serialize *samples* as an in-memory WAV and return it base64-encoded."""
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples, int(rate), format="WAV")
        wav_bytes = wav_buffer.getvalue()
        return base64.b64encode(wav_bytes).decode("ascii")

    def _as_audio_chunk(clip) -> dict:
        """Build the conversation entry for one audio clip."""
        # A two-element list/tuple is read as (samples, sample_rate); anything
        # else is taken to be bare samples at the fallback rate.
        is_pair = isinstance(clip, (list, tuple)) and len(clip) == 2
        samples, rate = clip if is_pair else (clip, fallback_sample_rate)
        return {
            "type": "audio",
            "base64": _encode_wav_b64(samples, rate),
        }

    def patched_get_inputs(prompts, images=None, videos=None, audios=None, **kwargs):
        # ``images``, ``videos`` and ``kwargs`` are accepted for signature
        # compatibility only; this override does not read them.
        batch = []
        for idx, prompt in enumerate(prompts):
            clips = audios[idx] if audios is not None else None
            if clips is None:
                clips = []
            elif not isinstance(clips, list):
                # A single clip was passed without a wrapping list.
                clips = [clips]

            # Audio entries first, then the text question.
            content: list[dict] = [_as_audio_chunk(clip) for clip in clips]
            content.append({"type": "text", "text": prompt})

            model_inputs = processor.apply_chat_template(
                [{"role": "user", "content": content}]
            )
            if hasattr(model_inputs, "to"):
                model_inputs = model_inputs.to(dtype=hf_model.dtype)
            batch.append(model_inputs)

        return batch

    unpatched_generate = hf_model.model.generate

    def patched_generate(*args, **kwargs):
        """Run the real ``generate`` and drop the prompt part of the result."""
        prompt_ids = kwargs.get("input_ids")
        if prompt_ids is None and args:
            prompt_ids = args[0]
        prompt_len = 0 if prompt_ids is None else prompt_ids.shape[1]

        generated = unpatched_generate(*args, **kwargs)
        if not prompt_len:
            return generated

        if isinstance(generated, torch.Tensor):
            return generated[:, prompt_len:]

        # Structured generate output: only ``sequences`` is trimmed, so any
        # scores/logits it carries stay available for
        # generate_greedy_logprobs_limit.
        generated.sequences = generated.sequences[:, prompt_len:]
        return generated

    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
    return hf_model
|
||||
|
||||
@@ -184,22 +184,42 @@ def get_text_token_prompts(
|
||||
text_prompt: str | None
|
||||
token_prompt: list[int]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
images = parsed_data.get("image", [])
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=[
|
||||
TextChunk(text=""),
|
||||
*(ImageChunk(image=image) for image in images),
|
||||
]
|
||||
),
|
||||
]
|
||||
# ChatCompletionRequest only supports ImageChunk natively;
|
||||
# for other modalities (e.g. audio), fall back to the model's
|
||||
# own dummy inputs builder which knows the right placeholders.
|
||||
has_non_image = any(
|
||||
k != "image" and count > 0 for k, count in mm_counts.items()
|
||||
)
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
|
||||
# Mistral does not support decode_tokens with skip_special_tokens=False
|
||||
text_prompt = None
|
||||
token_prompt = res.tokens
|
||||
if has_non_image:
|
||||
inputs = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
)
|
||||
text_prompt = None
|
||||
token_prompt = (
|
||||
inputs.prompt
|
||||
if isinstance(inputs.prompt, list)
|
||||
else tokenizer.encode(inputs.prompt, add_special_tokens=False)
|
||||
)
|
||||
else:
|
||||
images = parsed_data.get("image", [])
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=[
|
||||
TextChunk(text=""),
|
||||
*(ImageChunk(image=image) for image in images),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
|
||||
# Mistral does not support decode_tokens with
|
||||
# skip_special_tokens=False
|
||||
text_prompt = None
|
||||
token_prompt = res.tokens
|
||||
else:
|
||||
inputs = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
|
||||
@@ -291,6 +291,34 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
# skip validation here
|
||||
...
|
||||
|
||||
def _apply_hf_processor_mm_only(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Build the multimodal-only features for the given items.

    Each entry under ``"audios"`` in the processor data is flattened to a 1-D
    float32 array, passed through the audio processor's ``pad`` when the
    audio config is not streaming, and converted to a tensor. The tensors are
    returned under ``"audio_arrays"`` (the key is omitted when there is no
    audio), merged with the passthrough data from ``_get_hf_mm_data``.

    ``tokenization_kwargs`` is accepted but not read by this implementation.
    """
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

    raw_audios = processor_data.get("audios", [])
    if not isinstance(raw_audios, list):
        # A single audio was supplied without a wrapping list.
        raw_audios = [raw_audios]

    audio_processor = hf_processor._audio_processor
    audio_config = audio_processor.audio_config

    def _to_tensor(waveform) -> torch.Tensor:
        """Flatten one waveform, pad it unless streaming, and wrap it in a tensor."""
        samples = np.asarray(waveform, dtype=np.float32).ravel()
        if audio_config.is_streaming:
            return torch.tensor(samples)
        padded = audio_processor.pad(
            samples,
            hf_processor.sampling_rate,
            audio_config.is_streaming,
        )
        return torch.tensor(padded)

    audio_tensors: list[torch.Tensor] = [_to_tensor(item) for item in raw_audios]

    features = {"audio_arrays": audio_tensors} if audio_tensors else {}
    batch = BatchFeature(features)
    batch.update(passthrough_data)
    return batch
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import replace
|
||||
from functools import partial
|
||||
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
|
||||
try:
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
||||
except ImportError:
|
||||
AiterFlashAttentionBackend = None
|
||||
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from .utils import make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CausalRMSNorm = partial(RMSNorm, eps=1e-5)
|
||||
|
||||
|
||||
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
|
||||
)
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
# Override model_config-derived values with the actual
|
||||
# encoder values from kv_cache_spec
|
||||
self.num_heads_kv = kv_cache_spec.num_kv_heads
|
||||
self.headdim = kv_cache_spec.head_size
|
||||
# num_heads_q for the encoder is the same as num_kv_heads
|
||||
# (no GQA in whisper encoder)
|
||||
self.num_heads_q = kv_cache_spec.num_kv_heads
|
||||
|
||||
def build(
|
||||
self,
|
||||
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
output_block_scale,
|
||||
)
|
||||
|
||||
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||
_SUPPORTED_BACKENDS = tuple(
|
||||
b
|
||||
for b in (
|
||||
AiterFlashAttentionBackend,
|
||||
FlashAttentionBackend,
|
||||
RocmAttentionBackend,
|
||||
TritonAttentionBackend,
|
||||
)
|
||||
if b is not None
|
||||
)
|
||||
|
||||
if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS):
|
||||
raise NotImplementedError(
|
||||
f"{underlying_attn_backend} is not yet supported."
|
||||
"Contributions to support more backends are much "
|
||||
"appreciated."
|
||||
)
|
||||
|
||||
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||
logger.info(
|
||||
"Using %s for Whisper causal attention with block pooling. "
|
||||
"This backend was recently enabled for this model. "
|
||||
"If you encounter any accuracy or performance issues, "
|
||||
"please open an issue at "
|
||||
"https://github.com/vllm-project/vllm/issues "
|
||||
"with the [ROCm] tag so it can be triaged by the "
|
||||
"appropriate team.",
|
||||
underlying_attn_backend.get_name(),
|
||||
)
|
||||
|
||||
attn_backend = subclass_attention_backend_with_overrides(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
cache_dtype_str: (
|
||||
2,
|
||||
cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
# we stretch each block by `block_pool_size`
|
||||
block_size * block_pool_size,
|
||||
num_kv_heads // block_pool_size,
|
||||
head_size,
|
||||
), # TODO: generalize to other backends
|
||||
cache_dtype_str,
|
||||
),
|
||||
"forward_includes_kv_cache_update": True,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -43,8 +43,8 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token)
|
||||
self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token)
|
||||
self.start_token_id = tokenizer.tokenizer.get_special_token(self.start_token)
|
||||
self.end_token_id = tokenizer.tokenizer.get_special_token(self.end_token)
|
||||
|
||||
if self.start_token_id is None or self.end_token_id is None:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -517,7 +517,7 @@ class MistralTokenizer(TokenizerLike):
|
||||
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
|
||||
|
||||
non_skip_special_tokens_ids = {
|
||||
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
|
||||
self.tokenizer.get_special_token(SpecialTokens.tool_calls),
|
||||
}
|
||||
if isinstance(self.instruct, InstructTokenizerV13):
|
||||
if self.instruct.BEGIN_THINK:
|
||||
|
||||
@@ -425,8 +425,13 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
|
||||
sliding_window_configs: set[tuple[int, int] | None] = set()
|
||||
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer in layers.values():
|
||||
assert isinstance(layer.impl, AiterFlashAttentionImpl)
|
||||
for name, layer in layers.items():
|
||||
if name not in layer_names:
|
||||
continue
|
||||
assert isinstance(layer.impl, AiterFlashAttentionImpl), (
|
||||
"Aiter Flash Attention Metadata Builder can only be used "
|
||||
"with Aiter Flash Attention Impl."
|
||||
)
|
||||
sliding_window_configs.add(layer.impl.sliding_window)
|
||||
|
||||
while len(sliding_window_configs) > 0:
|
||||
|
||||
Reference in New Issue
Block a user