[Voxtral] Add new streaming arch (#32861)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2026-01-23 12:41:52 +01:00
committed by GitHub
parent 5da4c7d789
commit 3f3f89529d
9 changed files with 767 additions and 313 deletions

View File

@@ -404,6 +404,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
**extra_layer_kwargs,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -422,7 +423,9 @@ class LlamaModel(nn.Module):
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions, hidden_states, residual, **extra_layer_kwargs
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(