[Renderer] Separate out RendererConfig from ModelConfig (#30145)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, RendererConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
@@ -840,7 +840,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
renderer_config: RendererConfig,
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
@@ -861,7 +861,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type {task_type}")
|
||||
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
tokenizer = cached_tokenizer_from_config(renderer_config)
|
||||
chat = [dict(role="user", content=user_prompt)]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
@@ -882,10 +882,10 @@ class GraniteSpeechForConditionalGeneration(
|
||||
cls,
|
||||
audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
renderer_config: RendererConfig,
|
||||
) -> int | None:
|
||||
"""Get the number of audio tokens for an audio duration in sec."""
|
||||
processor = cached_processor_from_config(model_config)
|
||||
processor = cached_processor_from_config(renderer_config)
|
||||
hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
|
||||
proj_win_size = processor.audio_processor.projector_window_size
|
||||
ds_rate = processor.audio_processor.projector_downsample_rate
|
||||
@@ -903,7 +903,9 @@ class GraniteSpeechForConditionalGeneration(
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
cls,
|
||||
renderer_config: RendererConfig,
|
||||
task_type: str,
|
||||
) -> SpeechToTextConfig:
|
||||
"""Get the stt config for this model."""
|
||||
# Default settings are reasonable for this model and we don't currently
|
||||
|
||||
Reference in New Issue
Block a user