[Bugfix] Fix GLM-ASR audio encoder RoPE dim (#32540)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-18 19:17:59 +08:00
committed by GitHub
parent c826c72a96
commit 38bf2ffb21
2 changed files with 40 additions and 30 deletions

View File

@@ -89,6 +89,34 @@ def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
)
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Prepare engine args and a chat-templated prompt for GLM-ASR.

    Builds one `<|pad|>` audio placeholder per input clip, prepends them to
    the user question, and renders the prompt via the model's chat template.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses <|pad|> token for audio
    placeholders = "".join(["<|pad|>"] * audio_count)
    conversation = [{"role": "user", "content": placeholders + question}]
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the settings in this example are somewhat different from what is
@@ -358,34 +386,6 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
)
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Return request data (engine args + prompt) for the GLM-ASR model.

    One `<|pad|>` audio placeholder is emitted per audio clip and placed
    before the question text, then the tokenizer's chat template renders
    the final prompt string.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses <|pad|> token for audio
    content = "<|pad|>" * audio_count + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "Whisper only support single audio input per prompt"

View File

@@ -181,6 +181,12 @@ class GlmAsrEncoderAttention(nn.Module):
# Use vLLM's ApplyRotaryEmb CustomOp
# enforce_enable=True ensures the op is always enabled (important for ViT)
rope_params = getattr(config, "rope_parameters", None)
if rope_params:
partial_rotary_factor = rope_params.get("partial_rotary_factor", 0.5)
else:
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
self.rotary_dim = int(self.head_dim * partial_rotary_factor)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
# Use vLLM's MMEncoderAttention for hardware-optimized attention
@@ -226,8 +232,12 @@ class GlmAsrEncoderAttention(nn.Module):
# Apply rotary position embeddings using vLLM's ApplyRotaryEmb
# ApplyRotaryEmb expects x: [batch, seq, heads, head_dim]
# cos/sin: [seq_len, rotary_dim/2]
q = self.apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin)
k = self.apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin)
q[..., : self.rotary_dim] = self.apply_rotary_emb(
q[..., : self.rotary_dim], rotary_pos_emb_cos, rotary_pos_emb_sin
)
k[..., : self.rotary_dim] = self.apply_rotary_emb(
k[..., : self.rotary_dim], rotary_pos_emb_cos, rotary_pos_emb_sin
)
# MMEncoderAttention expects [batch, seq, num_heads, head_dim]
# It handles GQA internally via repeat_interleave